Author: tommaso
Date: Wed Oct 12 13:59:41 2016
New Revision: 1764472

URL: http://svn.apache.org/viewvc?rev=1764472&view=rev
Log:
refactored char/word/stacked RNNs in 2 classes

Added:
    labs/yay/trunk/core/src/main/java/org/apache/yay/RNN.java
      - copied, changed from r1764448, 
labs/yay/trunk/core/src/main/java/org/apache/yay/CharRNN.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/StackedRNN.java
      - copied, changed from r1764448, 
labs/yay/trunk/core/src/main/java/org/apache/yay/CharStackedRNN.java
    labs/yay/trunk/core/src/test/java/org/apache/yay/RNNCrossValidationTest.java
      - copied, changed from r1764448, 
labs/yay/trunk/core/src/test/java/org/apache/yay/CharRNNCrossValidationTest.java
Removed:
    labs/yay/trunk/core/src/main/java/org/apache/yay/CharRNN.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/CharStackedRNN.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/WordRNN.java
    
labs/yay/trunk/core/src/test/java/org/apache/yay/CharRNNCrossValidationTest.java
    
labs/yay/trunk/core/src/test/java/org/apache/yay/WordRNNCrossValidationTest.java

Copied: labs/yay/trunk/core/src/main/java/org/apache/yay/RNN.java (from 
r1764448, labs/yay/trunk/core/src/main/java/org/apache/yay/CharRNN.java)
URL: 
http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/RNN.java?p2=labs/yay/trunk/core/src/main/java/org/apache/yay/RNN.java&p1=labs/yay/trunk/core/src/main/java/org/apache/yay/CharRNN.java&r1=1764448&r2=1764472&rev=1764472&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/CharRNN.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/RNN.java Wed Oct 12 
13:59:41 2016
@@ -34,13 +34,13 @@ import java.util.Map;
 import java.util.Set;
 
 /**
- * A min char-level vanilla RNN model, based on Andrej Karpathy's python code.
+ * A min char/word-level vanilla RNN model, based on Andrej Karpathy's python 
code.
  * See also:
  *
  * @see <a href="http://karpathy.github.io/2015/05/21/rnn-effectiveness">The 
Unreasonable Effectiveness of Recurrent Neural Networks</a>
  * @see <a 
href="https://gist.github.com/karpathy/d4dee566867f8291f086">Minimal 
character-level language model with a Vanilla Recurrent Neural Network, in 
Python/numpy</a>
  */
-public class CharRNN {
+public class RNN {
 
   // hyperparameters
   protected final float learningRate; // size of hidden layer of neurons
@@ -62,11 +62,11 @@ public class CharRNN {
 
   private INDArray hPrev = null; // memory state
 
-  public CharRNN(float learningRate, int seqLength, int hiddenLayerSize, int 
epochs, String text) {
+  public RNN(float learningRate, int seqLength, int hiddenLayerSize, int 
epochs, String text) {
     this(learningRate, seqLength, hiddenLayerSize, epochs, text, true);
   }
 
-  public CharRNN(float learningRate, int seqLength, int hiddenLayerSize, int 
epochs, String text, boolean useChars) {
+  public RNN(float learningRate, int seqLength, int hiddenLayerSize, int 
epochs, String text, boolean useChars) {
     this.learningRate = learningRate;
     this.seqLength = seqLength;
     this.hiddenLayerSize = hiddenLayerSize;
@@ -312,6 +312,9 @@ public class CharRNN {
     NdIndexIterator ndIndexIterator = new NdIndexIterator(ixes.shape());
     while (ndIndexIterator.hasNext()) {
       int[] next = ndIndexIterator.next();
+      if (!useChars && txt.length() > 0) {
+        txt.append(' ');
+      }
       txt.append(ixToChar.get(ixes.getInt(next)));
     }
     return txt.toString();
@@ -323,7 +326,7 @@ public class CharRNN {
 
   @Override
   public String toString() {
-    return "CharRNN{" +
+    return "RNN{" +
             "learningRate=" + learningRate +
             ", seqLength=" + seqLength +
             ", hiddenLayerSize=" + hiddenLayerSize +
@@ -335,8 +338,8 @@ public class CharRNN {
 
 
   public String getHyperparamsString() {
-    return "CharRNN{" +
-            ", wxh=" + wxh +
+    return "RNN{" +
+            "wxh=" + wxh +
             ", whh=" + whh +
             ", why=" + why +
             ", bh=" + bh +

Copied: labs/yay/trunk/core/src/main/java/org/apache/yay/StackedRNN.java (from 
r1764448, labs/yay/trunk/core/src/main/java/org/apache/yay/CharStackedRNN.java)
URL: 
http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/StackedRNN.java?p2=labs/yay/trunk/core/src/main/java/org/apache/yay/StackedRNN.java&p1=labs/yay/trunk/core/src/main/java/org/apache/yay/CharStackedRNN.java&r1=1764448&r2=1764472&rev=1764472&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/CharStackedRNN.java 
(original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/StackedRNN.java Wed Oct 12 
13:59:41 2016
@@ -30,7 +30,7 @@ import java.util.LinkedList;
 import java.util.List;
 
 /**
- * A basic char-level stacked RNN model (2 hidden recurrent layers), based on 
Stacked RNN architecture from ICLR 2014's
+ * A basic char/word-level stacked RNN model (2 hidden recurrent layers), 
based on Stacked RNN architecture from ICLR 2014's
  * "How to Construct Deep Recurrent Neural Networks" by Razvan Pascanu, Caglar 
Gulcehre, Kyunghyun Cho and Yoshua Bengio
  * and Andrej Karpathy's notes on RNNs.
  * See also:
@@ -38,7 +38,7 @@ import java.util.List;
  * @see <a href="http://karpathy.github.io/2015/05/21/rnn-effectiveness">The 
Unreasonable Effectiveness of Recurrent Neural Networks</a>
  * @see <a href="https://arxiv.org/abs/1312.6026">How to Construct Deep 
Recurrent Neural Networks</a>
  */
-public class CharStackedRNN extends CharRNN {
+public class StackedRNN extends RNN {
 
   // model parameters
   private final INDArray wxh; // input to hidden
@@ -52,11 +52,11 @@ public class CharStackedRNN extends Char
   private INDArray hPrev = null; // memory state
   private INDArray hPrev2 = null; // memory state
 
-  public CharStackedRNN(float learningRate, int seqLength, int 
hiddenLayerSize, int epochs, String text) {
+  public StackedRNN(float learningRate, int seqLength, int hiddenLayerSize, 
int epochs, String text) {
     this(learningRate, seqLength, hiddenLayerSize, epochs, text, true);
   }
 
-  public CharStackedRNN(float learningRate, int seqLength, int 
hiddenLayerSize, int epochs, String text, boolean useChars) {
+  public StackedRNN(float learningRate, int seqLength, int hiddenLayerSize, 
int epochs, String text, boolean useChars) {
     super(learningRate, seqLength, hiddenLayerSize, epochs, text, useChars);
 
     wxh = Nd4j.randn(hiddenLayerSize, vocabSize).mul(0.01);

Copied: 
labs/yay/trunk/core/src/test/java/org/apache/yay/RNNCrossValidationTest.java 
(from r1764448, 
labs/yay/trunk/core/src/test/java/org/apache/yay/CharRNNCrossValidationTest.java)
URL: 
http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/RNNCrossValidationTest.java?p2=labs/yay/trunk/core/src/test/java/org/apache/yay/RNNCrossValidationTest.java&p1=labs/yay/trunk/core/src/test/java/org/apache/yay/CharRNNCrossValidationTest.java&r1=1764448&r2=1764472&rev=1764472&view=diff
==============================================================================
--- 
labs/yay/trunk/core/src/test/java/org/apache/yay/CharRNNCrossValidationTest.java
 (original)
+++ 
labs/yay/trunk/core/src/test/java/org/apache/yay/RNNCrossValidationTest.java 
Wed Oct 12 13:59:41 2016
@@ -30,17 +30,17 @@ import java.util.List;
 import java.util.Random;
 
 /**
- * CV tests for {@link CharRNN}
+ * CV tests for {@link RNN}
  */
 @RunWith(Parameterized.class)
-public class CharRNNCrossValidationTest {
+public class RNNCrossValidationTest {
 
   private float learningRate;
   private int seqLength;
   private int hiddenLayerSize;
   private Random r = new Random();
 
-  public CharRNNCrossValidationTest(float learningRate, int seqLength, int 
hiddenLayerSize) {
+  public RNNCrossValidationTest(float learningRate, int seqLength, int 
hiddenLayerSize) {
     this.learningRate = learningRate;
     this.seqLength = seqLength;
     this.hiddenLayerSize = hiddenLayerSize;
@@ -49,58 +49,56 @@ public class CharRNNCrossValidationTest
   @Parameterized.Parameters
   public static Collection<Object[]> data() {
     return Arrays.asList(new Object[][]{
-//            {1e-1f, 100, 25}, {1e-1f, 200, 512}, {5e-1f, 25, 25}, {1e-1f, 
250, 512},
-//            {5e-1f, 25, 100}, {5e-1f, 200, 50}, {5e-1f, 200, 40},  {5e-1f, 
100, 30}, {5e-1f, 100, 20}, {5e-1f, 250, 20}, {5e-1f, 250, 15},
-//            {5e-2f, 50, 64}, {3e-2f, 50, 128}, {5e-2f, 100, 128}, {5e-2f, 
100, 256}, {5e-2f, 100, 512}, {5e-2f, 100, 128},
-//            {5e-3f, 100, 256}, {5e-3f, 100, 512}, {5e-4f, 100, 128}, {5e-4f, 
100, 256},
-//            {5e-3f, 100, 100}, {5e-2f, 50, 100}
-            {4e-1f, 100, 10}
+            {1e-1f, 25, 100},
+            {3e-1f, 25, 100},
+            {3e-1f, 100, 25},
+            {1e-1f, 100, 25},
     });
   }
 
   @Test
-  public void testStackedCharRNNLearn() throws Exception {
+  public void testVanillaCharRNNLearn() throws Exception {
     InputStream resourceAsStream = 
getClass().getResourceAsStream("/word2vec/sentences.txt");
     String text = IOUtils.toString(resourceAsStream);
-    int epochs = 100;
-    CharRNN charRNN = new CharStackedRNN(learningRate, seqLength, 
hiddenLayerSize, epochs, text);
-    checkCorrectWordsRatio(text, charRNN);
+    int epochs = 10;
+    RNN RNN = new RNN(learningRate, seqLength, hiddenLayerSize, epochs, text);
+    checkCorrectWordsRatio(text, RNN);
   }
 
   @Test
-  public void testStackedWordRNNLearn() throws Exception {
+  public void testVanillaWordRNNLearn() throws Exception {
     InputStream resourceAsStream = 
getClass().getResourceAsStream("/word2vec/sentences.txt");
     String text = IOUtils.toString(resourceAsStream);
     int epochs = 100;
-    CharRNN charRNN = new CharStackedRNN(learningRate, seqLength, 
hiddenLayerSize, epochs, text, false);
-    checkCorrectWordsRatio(text, charRNN);
+    RNN RNN = new RNN(learningRate, seqLength, hiddenLayerSize, epochs, text, 
false);
+    checkCorrectWordsRatio(text, RNN);
   }
 
   @Test
-  public void testVanillaWordRNNLearn() throws Exception {
+  public void testStackedCharRNNLearn() throws Exception {
     InputStream resourceAsStream = 
getClass().getResourceAsStream("/word2vec/sentences.txt");
     String text = IOUtils.toString(resourceAsStream);
     int epochs = 100;
-    CharRNN charRNN = new CharRNN(learningRate, seqLength, hiddenLayerSize, 
epochs, text, false);
-    checkCorrectWordsRatio(text, charRNN);
+    RNN RNN = new StackedRNN(learningRate, seqLength, hiddenLayerSize, epochs, 
text);
+    checkCorrectWordsRatio(text, RNN);
   }
 
   @Test
-  public void testVanillaCharRNNLearn() throws Exception {
+  public void testStackedWordRNNLearn() throws Exception {
     InputStream resourceAsStream = 
getClass().getResourceAsStream("/word2vec/sentences.txt");
     String text = IOUtils.toString(resourceAsStream);
     int epochs = 100;
-    CharRNN charRNN = new CharRNN(learningRate, seqLength, hiddenLayerSize, 
epochs, text);
-    checkCorrectWordsRatio(text, charRNN);
+    RNN RNN = new StackedRNN(learningRate, seqLength, hiddenLayerSize, epochs, 
text, false);
+    checkCorrectWordsRatio(text, RNN);
   }
 
-  private void checkCorrectWordsRatio(String text, CharRNN charRNN) {
-    System.out.println(charRNN);
+  private void checkCorrectWordsRatio(String text, RNN RNN) {
+    System.out.println(RNN);
     List<String> words = Arrays.asList(text.split(" "));
-    charRNN.learn();
+    RNN.learn();
     for (int i = 0; i < 10; i++) {
       double c = 0;
-      String sample = charRNN.sample(r.nextInt(charRNN.getVocabSize()));
+      String sample = RNN.sample(r.nextInt(RNN.getVocabSize()));
       String[] sampleWords = sample.split(" ");
       for (String sw : sampleWords) {
         if (words.contains(sw)) {



---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@labs.apache.org
For additional commands, e-mail: commits-h...@labs.apache.org

Reply via email to