Author: tommaso
Date: Wed Oct 12 12:36:05 2016
New Revision: 1764448

URL: http://svn.apache.org/viewvc?rev=1764448&view=rev
Log:
added sRNN

Added:
    labs/yay/trunk/core/src/main/java/org/apache/yay/CharStackedRNN.java   
(with props)
Modified:
    labs/yay/trunk/core/src/main/java/org/apache/yay/CharRNN.java
    
labs/yay/trunk/core/src/test/java/org/apache/yay/CharRNNCrossValidationTest.java
    
labs/yay/trunk/core/src/test/java/org/apache/yay/WordRNNCrossValidationTest.java

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/CharRNN.java
URL: 
http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/CharRNN.java?rev=1764448&r1=1764447&r2=1764448&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/CharRNN.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/CharRNN.java Wed Oct 12 
12:36:05 2016
@@ -43,14 +43,15 @@ import java.util.Set;
 public class CharRNN {
 
   // hyperparameters
-  private final float learningRate; // size of hidden layer of neurons
-  private final int seqLength; // no. of steps to unroll the RNN for
-  private final int hiddenLayerSize;
-  private final int epochs;
-  private final int vocabSize;
-  private final Map<Character, Integer> charToIx;
-  private final Map<Integer, Character> ixToChar;
-  private final List<Character> data;
+  protected final float learningRate; // learning rate used for the parameter updates
+  protected final int seqLength; // no. of steps to unroll the RNN for
+  protected final int hiddenLayerSize; // size of hidden layer of neurons
+  protected final int epochs;
+  private final boolean useChars;
+  protected final int vocabSize;
+  protected final Map<String, Integer> charToIx;
+  protected final Map<Integer, String> ixToChar;
+  protected final List<String> data;
 
   // model parameters
   private final INDArray wxh; // input to hidden
@@ -62,22 +63,29 @@ public class CharRNN {
   private INDArray hPrev = null; // memory state
 
   public CharRNN(float learningRate, int seqLength, int hiddenLayerSize, int 
epochs, String text) {
+    this(learningRate, seqLength, hiddenLayerSize, epochs, text, true);
+  }
+
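+  /**
+   * When {@code useChars} is {@code false} the text is split on single spaces and the model
+   * operates on word tokens (a word-level RNN); when {@code true} it operates on characters,
+   * as in the original model.
+   */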
+  public CharRNN(float learningRate, int seqLength, int hiddenLayerSize, int 
epochs, String text, boolean useChars) {
     this.learningRate = learningRate;
     this.seqLength = seqLength;
     this.hiddenLayerSize = hiddenLayerSize;
     this.epochs = epochs;
-    char[] textChars = text.toCharArray();
+    this.useChars = useChars;
+
+    String[] textTokens = useChars ? toStrings(text.toCharArray()) : 
text.split(" ");
     data = new LinkedList<>();
-    for (char c : textChars) {
+    for (String c : textTokens) {
       data.add(c);
     }
-    Set<Character> chars = new HashSet<>(data);
-    vocabSize = chars.size();
-    System.out.printf("data has %d characters, %d unique.", data.size(), 
vocabSize);
+    Set<String> tokens = new HashSet<>(data);
+    vocabSize = tokens.size();
+
+    System.out.printf("data has %d tokens, %d unique.", data.size(), 
vocabSize);
     charToIx = new HashMap<>();
     ixToChar = new HashMap<>();
     int i = 0;
-    for (Character c : chars) {
+    for (String c : tokens) {
       charToIx.put(c, i);
       ixToChar.put(i, c);
       i++;
@@ -90,6 +98,14 @@ public class CharRNN {
     by = Nd4j.zeros(vocabSize, 1).mul(0.01);
   }
 
+  private String[] toStrings(char[] chars) {
+    String[] strings = new String[chars.length];
+    for (int i = 0; i < chars.length; i++) {
+      strings[i] = String.valueOf(chars[i]);
+    }
+    return strings;
+  }
+
   public void learn() {
 
     int currentEpoch = 0;
@@ -137,7 +153,7 @@ public class CharRNN {
       INDArray dby = Nd4j.zerosLike(by);
 
       // forward seqLength characters through the net and fetch gradient
-      double loss = lossFun(vocabSize, wxh, whh, why, bh, by, hPrev, inputs, 
targets, dWxh, dWhh, dWhy, dbh, dby);
+      double loss = lossFun(inputs, targets, dWxh, dWhh, dWhy, dbh, dby);
       smoothLoss = smoothLoss * 0.99 + loss * 0.001;
       if (Double.isNaN(smoothLoss)) {
         System.out.println("loss is NaN (over/underflow occured, try adjusting 
hyperparameters)");
@@ -168,10 +184,10 @@ public class CharRNN {
     }
   }
 
-  private INDArray getSequence(int p) {
+  protected INDArray getSequence(int p) {
     INDArray inputs = Nd4j.create(seqLength);
     int c = 0;
-    for (Character ch : data.subList(p, p + seqLength)) {
+    for (String ch : data.subList(p, p + seqLength)) {
       Integer ix = charToIx.get(ch);
       inputs.putScalar(c, ix);
       c++;
@@ -184,8 +200,7 @@ public class CharRNN {
    * hprev is Hx1 array of initial hidden state
    * returns the loss, gradients on model parameters and last hidden state
    */
-  private double lossFun(int vocabSize, INDArray wxh, INDArray whh, INDArray 
why, INDArray bh, INDArray by, INDArray hPrev,
-                         INDArray inputs, INDArray targets, INDArray dWxh, 
INDArray dWhh, INDArray dWhy, INDArray dbh,
+  private double lossFun(INDArray inputs, INDArray targets, INDArray dWxh, 
INDArray dWhh, INDArray dWhy, INDArray dbh,
                          INDArray dby) {
 
     INDArray xs = Nd4j.zeros(inputs.length(), vocabSize);
@@ -253,7 +268,7 @@ public class CharRNN {
     return loss;
   }
 
-  private INDArray init(int t, INDArray ast) {
+  protected INDArray init(int t, INDArray ast) {
     INDArray as;
     int[] aShape = ast.shape();
     int[] shape = new int[1 + aShape.length];
@@ -306,4 +321,26 @@ public class CharRNN {
     return vocabSize;
   }
 
+  @Override
+  public String toString() {
+    return "CharRNN{" +
+            "learningRate=" + learningRate +
+            ", seqLength=" + seqLength +
+            ", hiddenLayerSize=" + hiddenLayerSize +
+            ", epochs=" + epochs +
+            ", vocabSize=" + vocabSize +
+            ", useChars=" + useChars +
+            '}';
+  }
+
+  public String getParametersString() {
+    return "CharRNN{" +
+            "wxh=" + wxh +
+            ", whh=" + whh +
+            ", why=" + why +
+            ", bh=" + bh +
+            ", by=" + by +
+            '}';
+  }
 }

Added: labs/yay/trunk/core/src/main/java/org/apache/yay/CharStackedRNN.java
URL: 
http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/CharStackedRNN.java?rev=1764448&view=auto
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/CharStackedRNN.java (added)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/CharStackedRNN.java Wed 
Oct 12 12:36:05 2016
@@ -0,0 +1,296 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *  http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.yay;
+
+import org.apache.commons.math3.distribution.EnumeratedDistribution;
+import org.apache.commons.math3.util.Pair;
+import org.nd4j.linalg.api.iter.NdIndexIterator;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.impl.transforms.SetRange;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.ops.transforms.Transforms;
+
+import java.util.LinkedList;
+import java.util.List;
+
+/**
+ * A basic char-level stacked RNN model (2 hidden recurrent layers), based on the stacked RNN
+ * architecture from "How to Construct Deep Recurrent Neural Networks" (ICLR 2014) by Razvan Pascanu,
+ * Caglar Gulcehre, Kyunghyun Cho and Yoshua Bengio, and on Andrej Karpathy's notes on RNNs.
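+ * <p>
+ * Per time step {@code t} the stack computes:
+ * <pre>
+ *   h1_t = tanh(Wxh * x_t  + Whh  * h1_{t-1} + bh)
+ *   h2_t = tanh(Whh * h1_t + Whh2 * h2_{t-1} + bh2)
+ *   y_t  = Wh2y * h2_t + by,    p_t = softmax(y_t)
+ * </pre>
+ * Note that {@code whh} is shared between the first layer's recurrence and the
+ * layer-1-to-layer-2 connection.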
+ * See also:
+ *
+ * @see <a href="http://karpathy.github.io/2015/05/21/rnn-effectiveness">The Unreasonable Effectiveness of Recurrent Neural Networks</a>
+ * @see <a href="https://arxiv.org/abs/1312.6026">How to Construct Deep Recurrent Neural Networks</a>
+ */
+public class CharStackedRNN extends CharRNN {
+
+  // model parameters
+  private final INDArray wxh; // input to hidden
+  private final INDArray whh; // hidden to hidden
+  private final INDArray whh2; // hidden to hidden2
+  private final INDArray wh2y; // hidden2 to output
+  private final INDArray bh; // hidden bias
+  private final INDArray bh2; // hidden2 bias
+  private final INDArray by; // output bias
+
+  private INDArray hPrev = null; // memory state
+  private INDArray hPrev2 = null; // memory state
+
+  public CharStackedRNN(float learningRate, int seqLength, int 
hiddenLayerSize, int epochs, String text) {
+    this(learningRate, seqLength, hiddenLayerSize, epochs, text, true);
+  }
+
+  public CharStackedRNN(float learningRate, int seqLength, int 
hiddenLayerSize, int epochs, String text, boolean useChars) {
+    super(learningRate, seqLength, hiddenLayerSize, epochs, text, useChars);
+
+    wxh = Nd4j.randn(hiddenLayerSize, vocabSize).mul(0.01);
+    whh = Nd4j.randn(hiddenLayerSize, hiddenLayerSize).mul(0.01);
+    whh2 = Nd4j.randn(hiddenLayerSize, hiddenLayerSize).mul(0.01);
+    wh2y = Nd4j.randn(vocabSize, hiddenLayerSize).mul(0.01);
+    bh = Nd4j.zeros(hiddenLayerSize, 1).mul(0.01);
+    bh2 = Nd4j.zeros(hiddenLayerSize, 1).mul(0.01);
+    by = Nd4j.zeros(vocabSize, 1).mul(0.01);
+  }
+
+  public void learn() {
+
+    int currentEpoch = 0;
+
+    int n = 0;
+    int p = 0;
+
+    // memory variables for Adagrad
+    INDArray mWxh = Nd4j.zerosLike(wxh);
+    INDArray mWhh = Nd4j.zerosLike(whh);
+    INDArray mWhh2 = Nd4j.zerosLike(whh2);
+    INDArray mWh2y = Nd4j.zerosLike(wh2y);
+
+    INDArray mbh = Nd4j.zerosLike(bh);
+    INDArray mbh2 = Nd4j.zerosLike(bh2);
+    INDArray mby = Nd4j.zerosLike(by);
+
+    // loss at iteration 0
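+    // (the cross-entropy of a uniform prediction, -log(1/vocabSize), summed over one seqLength-long window)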
+    double smoothLoss = -Math.log(1.0 / vocabSize) * seqLength;
+
+    while (true) {
+      // prepare inputs (we're sweeping from left to right in steps seqLength 
long)
+      if (p + seqLength + 1 >= data.size() || n == 0) {
+        hPrev = Nd4j.zeros(hiddenLayerSize, 1); // reset RNN memory
+        hPrev2 = Nd4j.zeros(hiddenLayerSize, 1); // reset RNN memory
+        p = 0; // go from start of data
+        currentEpoch++;
+        if (currentEpoch == epochs) {
+          System.out.println("training finished: e:" + epochs + ", l: " + 
smoothLoss + ", h:(" + learningRate + ", " + seqLength + ", " + hiddenLayerSize 
+ ")");
+          break;
+        }
+      }
+
+      INDArray inputs = getSequence(p);
+      INDArray targets = getSequence(p + 1);
+
+      // sample from the model every now and then
+      if (n % 1000 == 0) {
+        String txt = sample(inputs.getInt(0));
+        System.out.printf("\n---\n %s \n----\n", txt);
+      }
+
+      INDArray dWxh = Nd4j.zerosLike(wxh);
+      INDArray dWhh = Nd4j.zerosLike(whh);
+      INDArray dWhh2 = Nd4j.zerosLike(whh2);
+      INDArray dWh2y = Nd4j.zerosLike(wh2y);
+
+      INDArray dbh = Nd4j.zerosLike(bh);
+      INDArray dbh2 = Nd4j.zerosLike(bh2);
+      INDArray dby = Nd4j.zerosLike(by);
+
+      // forward seqLength characters through the net and fetch gradient
+      double loss = lossFun(inputs, targets, dWxh, dWhh, dWhh2, dWh2y, dbh, 
dbh2, dby);
+      smoothLoss = smoothLoss * 0.99 + loss * 0.001;
+      if (Double.isNaN(smoothLoss)) {
+        System.out.println("loss is NaN (over/underflow occured, try adjusting 
hyperparameters)");
+        break;
+      }
+      if (n % 1000 == 0) {
+        System.out.printf("iter %d, loss: %f\n", n, smoothLoss); // print 
progress
+      }
+
+      // perform parameter update with Adagrad
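+      // Adagrad: each memory matrix accumulates squared gradients (m += g*g) and each parameter
+      // is updated as p -= learningRate * g / sqrt(m + 1e-8)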
+      mWxh.addi(dWxh.mul(dWxh));
+      wxh.subi((dWxh.mul(learningRate)).div(Transforms.sqrt(mWxh.addi(1e-8))));
+
+      mWhh.addi(dWhh.mul(dWhh));
+      whh.subi(dWhh.mul(learningRate).div(Transforms.sqrt(mWhh.addi(1e-8))));
+
+      mWhh2.addi(dWhh2.mul(dWhh2));
+      
whh2.subi(dWhh2.mul(learningRate).div(Transforms.sqrt(mWhh2.addi(1e-8))));
+
+      mbh2.addi(dbh2.mul(dbh2));
+      bh2.subi(dbh2.mul(learningRate).div(Transforms.sqrt(mbh2.addi(1e-8))));
+
+      mWh2y.addi(dWh2y.mul(dWh2y));
+      
wh2y.subi(dWh2y.mul(learningRate).div(Transforms.sqrt(mWh2y.addi(1e-8))));
+
+      mbh.addi(dbh.mul(dbh));
+      bh.subi(dbh.mul(learningRate).div(Transforms.sqrt(mbh.addi(1e-8))));
+
+      mby.addi(dby.mul(dby));
+      by.subi(dby.mul(learningRate).div(Transforms.sqrt(mby.addi(1e-8))));
+
+      p += seqLength; // move data pointer
+      n++; // iteration counter
+    }
+  }
+
+  /**
+   * inputs, targets are both list of integers
+   * hprev is Hx1 array of initial hidden state
+   * returns the loss, gradients on model parameters and last hidden state
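+   * <p>The returned loss is the cross-entropy accumulated over the sequence,
+   * {@code -sum_t log p_t[target_t]}; gradients are clipped to the range [-5, 5] before returning.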
+   */
+  private double lossFun(INDArray inputs, INDArray targets, INDArray dWxh, 
INDArray dWhh, INDArray dWhh2, INDArray dWh2y,
+                         INDArray dbh, INDArray dbh2, INDArray dby) {
+
+    INDArray xs = Nd4j.zeros(inputs.length(), vocabSize);
+    INDArray hs = null;
+    INDArray hs2 = null;
+    INDArray ys = null;
+    INDArray ps = null;
+
+    INDArray hs1 = Nd4j.create(hPrev.shape());
+    Nd4j.copy(hPrev, hs1);
+
+    INDArray hs12 = Nd4j.create(hPrev2.shape());
+    Nd4j.copy(hPrev2, hs12);
+
+    double loss = 0;
+
+    // forward pass
+    for (int t = 0; t < inputs.length(); t++) {
+      int tIndex = inputs.getScalar(t).getInt(0);
+      xs.putScalar(t, tIndex, 1); // encode in 1-of-k representation
+
+      INDArray hsRow = t == 0 ? hs1 : hs.getRow(t - 1);
+      INDArray hst = 
Transforms.tanh((wxh.mmul(xs.getRow(t).transpose())).add((whh.mmul(hsRow)).add(bh)));
 // hidden state
+      if (hs == null) {
+        hs = init(inputs.length(), hst);
+      }
+      hs.putRow(t, hst);
+
+      INDArray hs2Row = t == 0 ? hs12 : hs2.getRow(t - 1);
+      INDArray hst2 = Transforms.tanh((whh.mmul(hst)).add((whh2.mmul(hs2Row)).add(bh2))); // hidden state 2
+      if (hs2 == null) {
+        hs2 = init(inputs.length(), hst2);
+      }
+      hs2.putRow(t, hst2);
+
+      INDArray yst = (wh2y.mmul(hst2)).add(by); // unnormalized log probabilities for next chars
+      if (ys == null) {
+        ys = init(inputs.length(), yst);
+      }
+      ys.putRow(t, yst);
+      INDArray exp = Transforms.exp(yst);
+      Number sumExp = exp.sumNumber();
+      INDArray pst = exp.div(sumExp); // probabilities for next chars
+      if (ps == null) {
+        ps = init(inputs.length(), pst);
+      }
+      ps.putRow(t, pst);
+      loss += -Transforms.log(ps.getRow(t).getRow(targets.getInt(t)), 
true).sumNumber().doubleValue(); // softmax (cross-entropy loss)
+    }
+
+    this.hPrev = hs.getRow(inputs.length() - 1);
+    this.hPrev2 = hs2.getRow(inputs.length() - 1);
+
+    // backward pass: compute gradients going backwards
+    INDArray dhNext = Nd4j.zerosLike(hs.getRow(0));
+    INDArray dh2Next = Nd4j.zerosLike(hs2.getRow(0));
+    for (int t = inputs.length() - 1; t >= 0; t--) {
+      INDArray dy = ps.getRow(t).dup();
+      dy.putRow(targets.getInt(t), dy.getRow(targets.getInt(t)).sub(1)); // 
backprop into y
+
+      INDArray hst2 = hs2.getRow(t);
+      dWh2y.addi(dy.mmul(hst2.transpose()));
+      dby.addi(dy);
+      INDArray dh2 = wh2y.transpose().mmul(dy).add(dh2Next); // backprop into 
h2
+      INDArray dhraw2 = (Nd4j.ones(hst2.shape()).sub(hst2.mul(hst2))).mul(dh2); // backprop through tanh nonlinearity (1 - h2^2)
+      dbh2.addi(dhraw2);
+
+      INDArray hst = hs.getRow(t);
+      INDArray hs2PrevRow = t == 0 ? hs12 : hs2.getRow(t - 1);
+      dWhh2.addi(dhraw2.mmul(hs2PrevRow.transpose())); // gradient of the second layer's recurrent weights
+      dWhh.addi(dhraw2.mmul(hst.transpose())); // whh also connects layer 1 to layer 2 in the forward pass
+      dh2Next = whh2.transpose().mmul(dhraw2);
+      INDArray dh = whh.transpose().mmul(dhraw2).add(dhNext); // backprop into h
+      INDArray dhraw = (Nd4j.ones(hst.shape()).sub(hst.mul(hst))).mul(dh); // backprop through tanh nonlinearity (1 - h^2)
+      dbh.addi(dhraw);
+
+      dWxh.addi(dhraw.mmul(xs.getRow(t)));
+      INDArray hsRow = t == 0 ? hs1 : hs.getRow(t - 1);
+      dWhh.addi(dhraw.mmul(hsRow.transpose()));
+      dhNext = whh.transpose().mmul(dhraw);
+
+    }
+    // clip exploding gradients
+    Nd4j.getExecutioner().execAndReturn(new SetRange(dWxh, -5, 5));
+    Nd4j.getExecutioner().execAndReturn(new SetRange(dWhh, -5, 5));
+    Nd4j.getExecutioner().execAndReturn(new SetRange(dWhh2, -5, 5));
+    Nd4j.getExecutioner().execAndReturn(new SetRange(dWh2y, -5, 5));
+    Nd4j.getExecutioner().execAndReturn(new SetRange(dbh, -5, 5));
+    Nd4j.getExecutioner().execAndReturn(new SetRange(dbh2, -5, 5));
+    Nd4j.getExecutioner().execAndReturn(new SetRange(dby, -5, 5));
+
+    return loss;
+  }
+
+  /**
+   * Samples a sequence of integers from the model, starting from the current (hPrev, hPrev2)
+   * memory state; seedIx is the index of the seed token for the first time step.
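+   * <p>A usage sketch (the hyperparameter values are illustrative only):
+   * <pre>
+   *   CharStackedRNN rnn = new CharStackedRNN(1e-1f, 25, 100, 10, text);
+   *   rnn.learn();
+   *   String generated = rnn.sample(0); // seed with the token whose index is 0
+   * </pre>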
+   */
+  public String sample(int seedIx) {
+
+    INDArray x = Nd4j.zeros(vocabSize, 1);
+    x.putScalar(seedIx, 1);
+    int sampleSize = 200;
+    INDArray ixes = Nd4j.create(sampleSize);
+
+    INDArray h = hPrev.dup(); // work on copies so sampling does not disturb the training state
+    INDArray h2 = hPrev2.dup();
+
+    for (int t = 0; t < sampleSize; t++) {
+      h = Transforms.tanh((wxh.mmul(x)).add((whh.mmul(h)).add(bh)));
+      h2 = Transforms.tanh((whh.mmul(h)).add((whh2.mmul(h2)).add(bh2)));
+      INDArray y = (wh2y.mmul(h2)).add(by);
+      INDArray exp = Transforms.exp(y);
+      INDArray pm = exp.div(Nd4j.sum(exp)).ravel();
+
+      List<Pair<Integer, Double>> d = new LinkedList<>();
+      for (int pi = 0; pi < vocabSize; pi++) {
+        d.add(new Pair<>(pi, pm.getDouble(0, pi)));
+      }
+      EnumeratedDistribution<Integer> distribution = new 
EnumeratedDistribution<>(d);
+
+      int ix = distribution.sample();
+
+      x = Nd4j.zeros(vocabSize, 1);
+      x.putScalar(ix, 1);
+      ixes.putScalar(t, ix);
+    }
+
+    StringBuilder txt = new StringBuilder();
+
+    NdIndexIterator ndIndexIterator = new NdIndexIterator(ixes.shape());
+    while (ndIndexIterator.hasNext()) {
+      int[] next = ndIndexIterator.next();
+      txt.append(ixToChar.get(ixes.getInt(next)));
+    }
+    return txt.toString();
+  }
+
+}

Propchange: labs/yay/trunk/core/src/main/java/org/apache/yay/CharStackedRNN.java
------------------------------------------------------------------------------
    svn:eol-style = native

Modified: 
labs/yay/trunk/core/src/test/java/org/apache/yay/CharRNNCrossValidationTest.java
URL: 
http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/CharRNNCrossValidationTest.java?rev=1764448&r1=1764447&r2=1764448&view=diff
==============================================================================
--- 
labs/yay/trunk/core/src/test/java/org/apache/yay/CharRNNCrossValidationTest.java
 (original)
+++ 
labs/yay/trunk/core/src/test/java/org/apache/yay/CharRNNCrossValidationTest.java
 Wed Oct 12 12:36:05 2016
@@ -49,21 +49,53 @@ public class CharRNNCrossValidationTest
   @Parameterized.Parameters
   public static Collection<Object[]> data() {
     return Arrays.asList(new Object[][]{
-            {1e-1f, 100, 25}, {1e-1f, 200, 512}, {1e-1f, 25, 25}, {1e-1f, 250, 
512},
-            {1e-1f, 25, 100}, {1e-1f, 200, 50}, {1e-1f, 200, 40}, {1e-1f, 100, 
30}, {1e-1f, 100, 20}, {1e-1f, 250, 20}, {1e-1f, 250, 15},
-            {1e-2f, 50, 64}, {3e-2f, 50, 128}, {1e-2f, 100, 128}, {1e-2f, 100, 
256}, {1e-2f, 100, 512}, {1e-2f, 100, 128},
-            {1e-3f, 100, 256}, {1e-3f, 100, 512}, {1e-4f, 100, 128}, {1e-4f, 
100, 256},
-            {1e-3f, 100, 100},
+//            {1e-1f, 100, 25}, {1e-1f, 200, 512}, {5e-1f, 25, 25}, {1e-1f, 
250, 512},
+//            {5e-1f, 25, 100}, {5e-1f, 200, 50}, {5e-1f, 200, 40},  {5e-1f, 
100, 30}, {5e-1f, 100, 20}, {5e-1f, 250, 20}, {5e-1f, 250, 15},
+//            {5e-2f, 50, 64}, {3e-2f, 50, 128}, {5e-2f, 100, 128}, {5e-2f, 
100, 256}, {5e-2f, 100, 512}, {5e-2f, 100, 128},
+//            {5e-3f, 100, 256}, {5e-3f, 100, 512}, {5e-4f, 100, 128}, {5e-4f, 
100, 256},
+//            {5e-3f, 100, 100}, {5e-2f, 50, 100}
+            {4e-1f, 100, 10}
     });
   }
 
   @Test
-  public void testLearnWithDifferentHyperparameters() throws Exception {
-    System.out.println("hyperparameters: " + learningRate + ", " + seqLength + 
", " + hiddenLayerSize);
-    InputStream resourceAsStream = 
getClass().getResourceAsStream("/word2vec/abstracts.txt");
+  public void testStackedCharRNNLearn() throws Exception {
+    InputStream resourceAsStream = 
getClass().getResourceAsStream("/word2vec/sentences.txt");
     String text = IOUtils.toString(resourceAsStream);
-    int epochs = 1000000;
+    int epochs = 100;
+    CharRNN charRNN = new CharStackedRNN(learningRate, seqLength, 
hiddenLayerSize, epochs, text);
+    checkCorrectWordsRatio(text, charRNN);
+  }
+
+  @Test
+  public void testStackedWordRNNLearn() throws Exception {
+    InputStream resourceAsStream = 
getClass().getResourceAsStream("/word2vec/sentences.txt");
+    String text = IOUtils.toString(resourceAsStream);
+    int epochs = 100;
+    CharRNN charRNN = new CharStackedRNN(learningRate, seqLength, 
hiddenLayerSize, epochs, text, false);
+    checkCorrectWordsRatio(text, charRNN);
+  }
+
+  @Test
+  public void testVanillaWordRNNLearn() throws Exception {
+    InputStream resourceAsStream = 
getClass().getResourceAsStream("/word2vec/sentences.txt");
+    String text = IOUtils.toString(resourceAsStream);
+    int epochs = 100;
+    CharRNN charRNN = new CharRNN(learningRate, seqLength, hiddenLayerSize, 
epochs, text, false);
+    checkCorrectWordsRatio(text, charRNN);
+  }
+
+  @Test
+  public void testVanillaCharRNNLearn() throws Exception {
+    InputStream resourceAsStream = 
getClass().getResourceAsStream("/word2vec/sentences.txt");
+    String text = IOUtils.toString(resourceAsStream);
+    int epochs = 100;
     CharRNN charRNN = new CharRNN(learningRate, seqLength, hiddenLayerSize, 
epochs, text);
+    checkCorrectWordsRatio(text, charRNN);
+  }
+
+  private void checkCorrectWordsRatio(String text, CharRNN charRNN) {
+    System.out.println(charRNN);
     List<String> words = Arrays.asList(text.split(" "));
     charRNN.learn();
     for (int i = 0; i < 10; i++) {
@@ -79,6 +111,7 @@ public class CharRNNCrossValidationTest
         c /= sample.length();
       }
       System.out.println("correct word ratio: " + c);
+
     }
   }
 

Modified: 
labs/yay/trunk/core/src/test/java/org/apache/yay/WordRNNCrossValidationTest.java
URL: 
http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/WordRNNCrossValidationTest.java?rev=1764448&r1=1764447&r2=1764448&view=diff
==============================================================================
--- 
labs/yay/trunk/core/src/test/java/org/apache/yay/WordRNNCrossValidationTest.java
 (original)
+++ 
labs/yay/trunk/core/src/test/java/org/apache/yay/WordRNNCrossValidationTest.java
 Wed Oct 12 12:36:05 2016
@@ -49,11 +49,12 @@ public class WordRNNCrossValidationTest
   @Parameterized.Parameters
   public static Collection<Object[]> data() {
     return Arrays.asList(new Object[][]{
-            {1e-1f, 100, 25}, {1e-1f, 200, 512}, {1e-1f, 25, 25}, {1e-1f, 250, 
512},
-            {1e-1f, 25, 100}, {1e-1f, 200, 50}, {1e-1f, 200, 40}, {1e-1f, 100, 
30}, {1e-1f, 100, 20}, {1e-1f, 250, 20}, {1e-1f, 250, 15},
-            {1e-2f, 50, 64}, {3e-2f, 50, 128}, {1e-2f, 100, 128}, {1e-2f, 100, 
256}, {1e-2f, 100, 512}, {1e-2f, 100, 128},
-            {1e-3f, 100, 256}, {1e-3f, 100, 512}, {1e-4f, 100, 128}, {1e-4f, 
100, 256},
-            {2e-1f, 25, 100},
+//            {1e-1f, 100, 25}, {1e-1f, 200, 512}, {1e-1f, 25, 25}, {1e-1f, 
250, 512},
+//            {1e-1f, 25, 100}, {1e-1f, 200, 50}, {1e-1f, 200, 40}, {1e-1f, 
100, 30}, {1e-1f, 100, 20}, {1e-1f, 250, 20}, {1e-1f, 250, 15},
+//            {1e-2f, 50, 64}, {3e-2f, 50, 128}, {1e-2f, 100, 128}, {1e-2f, 
100, 256}, {1e-2f, 100, 512}, {1e-2f, 100, 128},
+//            {1e-3f, 100, 256}, {1e-3f, 100, 512}, {1e-4f, 100, 128}, {1e-4f, 
100, 256},
+//            {2e-1f, 25, 100},
+            {5e-2f, 50, 64}
     });
   }
 


