Author: tommaso Date: Wed Oct 12 13:59:41 2016 New Revision: 1764472 URL: http://svn.apache.org/viewvc?rev=1764472&view=rev Log: refactored char/word/stacked RNNs into 2 classes
Added: labs/yay/trunk/core/src/main/java/org/apache/yay/RNN.java - copied, changed from r1764448, labs/yay/trunk/core/src/main/java/org/apache/yay/CharRNN.java labs/yay/trunk/core/src/main/java/org/apache/yay/StackedRNN.java - copied, changed from r1764448, labs/yay/trunk/core/src/main/java/org/apache/yay/CharStackedRNN.java labs/yay/trunk/core/src/test/java/org/apache/yay/RNNCrossValidationTest.java - copied, changed from r1764448, labs/yay/trunk/core/src/test/java/org/apache/yay/CharRNNCrossValidationTest.java Removed: labs/yay/trunk/core/src/main/java/org/apache/yay/CharRNN.java labs/yay/trunk/core/src/main/java/org/apache/yay/CharStackedRNN.java labs/yay/trunk/core/src/main/java/org/apache/yay/WordRNN.java labs/yay/trunk/core/src/test/java/org/apache/yay/CharRNNCrossValidationTest.java labs/yay/trunk/core/src/test/java/org/apache/yay/WordRNNCrossValidationTest.java Copied: labs/yay/trunk/core/src/main/java/org/apache/yay/RNN.java (from r1764448, labs/yay/trunk/core/src/main/java/org/apache/yay/CharRNN.java) URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/RNN.java?p2=labs/yay/trunk/core/src/main/java/org/apache/yay/RNN.java&p1=labs/yay/trunk/core/src/main/java/org/apache/yay/CharRNN.java&r1=1764448&r2=1764472&rev=1764472&view=diff ============================================================================== --- labs/yay/trunk/core/src/main/java/org/apache/yay/CharRNN.java (original) +++ labs/yay/trunk/core/src/main/java/org/apache/yay/RNN.java Wed Oct 12 13:59:41 2016 @@ -34,13 +34,13 @@ import java.util.Map; import java.util.Set; /** - * A min char-level vanilla RNN model, based on Andrej Karpathy's python code. + * A min char/word-level vanilla RNN model, based on Andrej Karpathy's python code. 
* See also: * * @see <a href="http://karpathy.github.io/2015/05/21/rnn-effectiveness">The Unreasonable Effectiveness of Recurrent Neural Networks</a> * @see <a href="https://gist.github.com/karpathy/d4dee566867f8291f086">Minimal character-level language model with a Vanilla Recurrent Neural Network, in Python/numpy</a> */ -public class CharRNN { +public class RNN { // hyperparameters protected final float learningRate; // size of hidden layer of neurons @@ -62,11 +62,11 @@ public class CharRNN { private INDArray hPrev = null; // memory state - public CharRNN(float learningRate, int seqLength, int hiddenLayerSize, int epochs, String text) { + public RNN(float learningRate, int seqLength, int hiddenLayerSize, int epochs, String text) { this(learningRate, seqLength, hiddenLayerSize, epochs, text, true); } - public CharRNN(float learningRate, int seqLength, int hiddenLayerSize, int epochs, String text, boolean useChars) { + public RNN(float learningRate, int seqLength, int hiddenLayerSize, int epochs, String text, boolean useChars) { this.learningRate = learningRate; this.seqLength = seqLength; this.hiddenLayerSize = hiddenLayerSize; @@ -312,6 +312,9 @@ public class CharRNN { NdIndexIterator ndIndexIterator = new NdIndexIterator(ixes.shape()); while (ndIndexIterator.hasNext()) { int[] next = ndIndexIterator.next(); + if (!useChars && txt.length() > 0) { + txt.append(' '); + } txt.append(ixToChar.get(ixes.getInt(next))); } return txt.toString(); @@ -323,7 +326,7 @@ public class CharRNN { @Override public String toString() { - return "CharRNN{" + + return "RNN{" + "learningRate=" + learningRate + ", seqLength=" + seqLength + ", hiddenLayerSize=" + hiddenLayerSize + @@ -335,8 +338,8 @@ public class CharRNN { public String getHyperparamsString() { - return "CharRNN{" + - ", wxh=" + wxh + + return "RNN{" + + "wxh=" + wxh + ", whh=" + whh + ", why=" + why + ", bh=" + bh + Copied: labs/yay/trunk/core/src/main/java/org/apache/yay/StackedRNN.java (from r1764448, 
labs/yay/trunk/core/src/main/java/org/apache/yay/CharStackedRNN.java) URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/StackedRNN.java?p2=labs/yay/trunk/core/src/main/java/org/apache/yay/StackedRNN.java&p1=labs/yay/trunk/core/src/main/java/org/apache/yay/CharStackedRNN.java&r1=1764448&r2=1764472&rev=1764472&view=diff ============================================================================== --- labs/yay/trunk/core/src/main/java/org/apache/yay/CharStackedRNN.java (original) +++ labs/yay/trunk/core/src/main/java/org/apache/yay/StackedRNN.java Wed Oct 12 13:59:41 2016 @@ -30,7 +30,7 @@ import java.util.LinkedList; import java.util.List; /** - * A basic char-level stacked RNN model (2 hidden recurrent layers), based on Stacked RNN architecture from ICLR 2014's + * A basic char/word-level stacked RNN model (2 hidden recurrent layers), based on Stacked RNN architecture from ICLR 2014's * "How to Construct Deep Recurrent Neural Networks" by Razvan Pascanu, Caglar Gulcehre, Kyunghyun Cho and Yoshua Bengio * and Andrej Karpathy's notes on RNNs. 
* See also: @@ -38,7 +38,7 @@ import java.util.List; * @see <a href="http://karpathy.github.io/2015/05/21/rnn-effectiveness">The Unreasonable Effectiveness of Recurrent Neural Networks</a> * @see <a href="https://arxiv.org/abs/1312.6026">How to Construct Deep Recurrent Neural Networks</a> */ -public class CharStackedRNN extends CharRNN { +public class StackedRNN extends RNN { // model parameters private final INDArray wxh; // input to hidden @@ -52,11 +52,11 @@ public class CharStackedRNN extends Char private INDArray hPrev = null; // memory state private INDArray hPrev2 = null; // memory state - public CharStackedRNN(float learningRate, int seqLength, int hiddenLayerSize, int epochs, String text) { + public StackedRNN(float learningRate, int seqLength, int hiddenLayerSize, int epochs, String text) { this(learningRate, seqLength, hiddenLayerSize, epochs, text, true); } - public CharStackedRNN(float learningRate, int seqLength, int hiddenLayerSize, int epochs, String text, boolean useChars) { + public StackedRNN(float learningRate, int seqLength, int hiddenLayerSize, int epochs, String text, boolean useChars) { super(learningRate, seqLength, hiddenLayerSize, epochs, text, useChars); wxh = Nd4j.randn(hiddenLayerSize, vocabSize).mul(0.01); Copied: labs/yay/trunk/core/src/test/java/org/apache/yay/RNNCrossValidationTest.java (from r1764448, labs/yay/trunk/core/src/test/java/org/apache/yay/CharRNNCrossValidationTest.java) URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/RNNCrossValidationTest.java?p2=labs/yay/trunk/core/src/test/java/org/apache/yay/RNNCrossValidationTest.java&p1=labs/yay/trunk/core/src/test/java/org/apache/yay/CharRNNCrossValidationTest.java&r1=1764448&r2=1764472&rev=1764472&view=diff ============================================================================== --- labs/yay/trunk/core/src/test/java/org/apache/yay/CharRNNCrossValidationTest.java (original) +++ 
labs/yay/trunk/core/src/test/java/org/apache/yay/RNNCrossValidationTest.java Wed Oct 12 13:59:41 2016 @@ -30,17 +30,17 @@ import java.util.List; import java.util.Random; /** - * CV tests for {@link CharRNN} + * CV tests for {@link RNN} */ @RunWith(Parameterized.class) -public class CharRNNCrossValidationTest { +public class RNNCrossValidationTest { private float learningRate; private int seqLength; private int hiddenLayerSize; private Random r = new Random(); - public CharRNNCrossValidationTest(float learningRate, int seqLength, int hiddenLayerSize) { + public RNNCrossValidationTest(float learningRate, int seqLength, int hiddenLayerSize) { this.learningRate = learningRate; this.seqLength = seqLength; this.hiddenLayerSize = hiddenLayerSize; @@ -49,58 +49,56 @@ public class CharRNNCrossValidationTest @Parameterized.Parameters public static Collection<Object[]> data() { return Arrays.asList(new Object[][]{ -// {1e-1f, 100, 25}, {1e-1f, 200, 512}, {5e-1f, 25, 25}, {1e-1f, 250, 512}, -// {5e-1f, 25, 100}, {5e-1f, 200, 50}, {5e-1f, 200, 40}, {5e-1f, 100, 30}, {5e-1f, 100, 20}, {5e-1f, 250, 20}, {5e-1f, 250, 15}, -// {5e-2f, 50, 64}, {3e-2f, 50, 128}, {5e-2f, 100, 128}, {5e-2f, 100, 256}, {5e-2f, 100, 512}, {5e-2f, 100, 128}, -// {5e-3f, 100, 256}, {5e-3f, 100, 512}, {5e-4f, 100, 128}, {5e-4f, 100, 256}, -// {5e-3f, 100, 100}, {5e-2f, 50, 100} - {4e-1f, 100, 10} + {1e-1f, 25, 100}, + {3e-1f, 25, 100}, + {3e-1f, 100, 25}, + {1e-1f, 100, 25}, }); } @Test - public void testStackedCharRNNLearn() throws Exception { + public void testVanillaCharRNNLearn() throws Exception { InputStream resourceAsStream = getClass().getResourceAsStream("/word2vec/sentences.txt"); String text = IOUtils.toString(resourceAsStream); - int epochs = 100; - CharRNN charRNN = new CharStackedRNN(learningRate, seqLength, hiddenLayerSize, epochs, text); - checkCorrectWordsRatio(text, charRNN); + int epochs = 10; + RNN RNN = new RNN(learningRate, seqLength, hiddenLayerSize, epochs, text); + 
checkCorrectWordsRatio(text, RNN); } @Test - public void testStackedWordRNNLearn() throws Exception { + public void testVanillaWordRNNLearn() throws Exception { InputStream resourceAsStream = getClass().getResourceAsStream("/word2vec/sentences.txt"); String text = IOUtils.toString(resourceAsStream); int epochs = 100; - CharRNN charRNN = new CharStackedRNN(learningRate, seqLength, hiddenLayerSize, epochs, text, false); - checkCorrectWordsRatio(text, charRNN); + RNN RNN = new RNN(learningRate, seqLength, hiddenLayerSize, epochs, text, false); + checkCorrectWordsRatio(text, RNN); } @Test - public void testVanillaWordRNNLearn() throws Exception { + public void testStackedCharRNNLearn() throws Exception { InputStream resourceAsStream = getClass().getResourceAsStream("/word2vec/sentences.txt"); String text = IOUtils.toString(resourceAsStream); int epochs = 100; - CharRNN charRNN = new CharRNN(learningRate, seqLength, hiddenLayerSize, epochs, text, false); - checkCorrectWordsRatio(text, charRNN); + RNN RNN = new StackedRNN(learningRate, seqLength, hiddenLayerSize, epochs, text); + checkCorrectWordsRatio(text, RNN); } @Test - public void testVanillaCharRNNLearn() throws Exception { + public void testStackedWordRNNLearn() throws Exception { InputStream resourceAsStream = getClass().getResourceAsStream("/word2vec/sentences.txt"); String text = IOUtils.toString(resourceAsStream); int epochs = 100; - CharRNN charRNN = new CharRNN(learningRate, seqLength, hiddenLayerSize, epochs, text); - checkCorrectWordsRatio(text, charRNN); + RNN RNN = new StackedRNN(learningRate, seqLength, hiddenLayerSize, epochs, text, false); + checkCorrectWordsRatio(text, RNN); } - private void checkCorrectWordsRatio(String text, CharRNN charRNN) { - System.out.println(charRNN); + private void checkCorrectWordsRatio(String text, RNN RNN) { + System.out.println(RNN); List<String> words = Arrays.asList(text.split(" ")); - charRNN.learn(); + RNN.learn(); for (int i = 0; i < 10; i++) { double c = 0; - 
String sample = charRNN.sample(r.nextInt(charRNN.getVocabSize())); + String sample = RNN.sample(r.nextInt(RNN.getVocabSize())); String[] sampleWords = sample.split(" "); for (String sw : sampleWords) { if (words.contains(sw)) { --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@labs.apache.org For additional commands, e-mail: commits-h...@labs.apache.org