Author: tommaso
Date: Thu Oct  6 12:58:22 2016
New Revision: 1763584

URL: http://svn.apache.org/viewvc?rev=1763584&view=rev
Log:
char rnn refactoring, word rnn, added cv tests with different hyperparams

Added:
    labs/yay/trunk/core/src/main/java/org/apache/yay/CharRNN.java
      - copied, changed from r1763464, labs/yay/trunk/core/src/main/java/org/apache/yay/RNN.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/WordRNN.java   (with props)
    labs/yay/trunk/core/src/test/java/org/apache/yay/CharRNNCrossValidationTest.java
      - copied, changed from r1763464, labs/yay/trunk/core/src/test/java/org/apache/yay/RNNTest.java
    labs/yay/trunk/core/src/test/java/org/apache/yay/WordRNNCrossValidationTest.java   (with props)
Removed:
    labs/yay/trunk/core/src/main/java/org/apache/yay/RNN.java
    labs/yay/trunk/core/src/test/java/org/apache/yay/RNNTest.java
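
The refactoring replaces the old RNN#learn(String) entry point: the hyperparameters (learning rate, sequence length, hidden layer size, number of epochs) now go into the constructor together with the training text, training runs via learn(), and generation is exposed through sample(int seedIx) and getVocabSize(). A minimal usage sketch of the new API follows; the class name, corpus and hyperparameter values are illustrative only and not part of this commit.

    import java.util.Random;

    import org.apache.yay.CharRNN;
    import org.apache.yay.WordRNN;

    public class RnnUsageSketch {
      public static void main(String[] args) {
        // illustrative corpus; the CV tests below read /word2vec/abstracts.txt instead
        String text = "the quick brown fox jumps over the lazy dog the quick brown fox";
        Random r = new Random();

        // constructor arguments: learningRate, seqLength, hiddenLayerSize, epochs, text
        CharRNN charRNN = new CharRNN(1e-1f, 25, 100, 20, text);
        charRNN.learn(); // trains until the epoch count is reached (or the loss becomes NaN)
        System.out.println(charRNN.sample(r.nextInt(charRNN.getVocabSize())));

        // WordRNN exposes the same API but builds its vocabulary from whitespace-separated words
        WordRNN wordRNN = new WordRNN(1e-2f, 5, 64, 100, text);
        wordRNN.learn();
        System.out.println(wordRNN.sample(r.nextInt(wordRNN.getVocabSize())));
      }
    }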

Copied: labs/yay/trunk/core/src/main/java/org/apache/yay/CharRNN.java (from r1763464, labs/yay/trunk/core/src/main/java/org/apache/yay/RNN.java)
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/CharRNN.java?p2=labs/yay/trunk/core/src/main/java/org/apache/yay/CharRNN.java&p1=labs/yay/trunk/core/src/main/java/org/apache/yay/RNN.java&r1=1763464&r2=1763584&rev=1763584&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/RNN.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/CharRNN.java Thu Oct  6 12:58:22 2016
@@ -40,20 +40,42 @@ import java.util.Set;
  * @see <a href="http://karpathy.github.io/2015/05/21/rnn-effectiveness">The Unreasonable Effectiveness of Recurrent Neural Networks</a>
  * @see <a href="https://gist.github.com/karpathy/d4dee566867f8291f086">Minimal character-level language model with a Vanilla Recurrent Neural Network, in Python/numpy</a>
  */
-public class RNN {
-
-  public void learn(String text) {
+public class CharRNN {
 
+  // hyperparameters
+  private final float learningRate; // learning rate used in the Adagrad updates
+  private final int seqLength; // no. of steps to unroll the RNN for
+  private final int hiddenLayerSize; // size of hidden layer of neurons
+  private final int epochs;
+  private final int vocabSize;
+  private final Map<Character, Integer> charToIx;
+  private final Map<Integer, Character> ixToChar;
+  private final List<Character> data;
+
+  // model parameters
+  private final INDArray wxh; // input to hidden
+  private final INDArray whh; // hidden to hidden
+  private final INDArray why; // hidden to output
+  private final INDArray bh; // hidden bias
+  private final INDArray by; // output bias
+
+  private INDArray hPrev = null; // memory state
+
+  public CharRNN(float learningRate, int seqLength, int hiddenLayerSize, int 
epochs, String text) {
+    this.learningRate = learningRate;
+    this.seqLength = seqLength;
+    this.hiddenLayerSize = hiddenLayerSize;
+    this.epochs = epochs;
     char[] textChars = text.toCharArray();
-    List<Character> data = new LinkedList<>();
+    data = new LinkedList<>();
     for (char c : textChars) {
       data.add(c);
     }
     Set<Character> chars = new HashSet<>(data);
-    int vocabSize = chars.size();
+    vocabSize = chars.size();
     System.out.printf("data has %d characters, %d unique.", data.size(), 
vocabSize);
-    Map<Character, Integer> charToIx = new HashMap<>();
-    Map<Integer, Character> ixToChar = new HashMap<>();
+    charToIx = new HashMap<>();
+    ixToChar = new HashMap<>();
     int i = 0;
     for (Character c : chars) {
       charToIx.put(c, i);
@@ -61,17 +83,16 @@ public class RNN {
       i++;
     }
 
-    // hyperparameters
-    int hiddenSize = 40; // size of hidden layer of neurons
-    int seqLength = 10; // no. of steps to unroll the RNN for
-    float learningRate = 1e-2f;
-
-    // model parameters
-    INDArray wxh = Nd4j.randn(hiddenSize, vocabSize).mul(0.001); // input to 
hidden
-    INDArray whh = Nd4j.randn(hiddenSize, hiddenSize).mul(0.001); // hidden to 
hidden
-    INDArray why = Nd4j.randn(vocabSize, hiddenSize).mul(0.001); // hidden to 
output
-    INDArray bh = Nd4j.zeros(hiddenSize, 1); // hidden bias
-    INDArray by = Nd4j.zeros(vocabSize, 1); // output bias
+    wxh = Nd4j.randn(hiddenLayerSize, vocabSize).mul(0.01);
+    whh = Nd4j.randn(hiddenLayerSize, hiddenLayerSize).mul(0.01);
+    why = Nd4j.randn(vocabSize, hiddenLayerSize).mul(0.01);
+    bh = Nd4j.zeros(hiddenLayerSize, 1).mul(0.01);
+    by = Nd4j.zeros(vocabSize, 1).mul(0.01);
+  }
+
+  public void learn() {
+
+    int currentEpoch = 0;
 
     int n = 0;
     int p = 0;
@@ -87,32 +108,25 @@ public class RNN {
     // loss at iteration 0
     double smoothLoss = -Math.log(1.0 / vocabSize) * seqLength;
 
-    INDArray hPrev = null;
     while (true) {
       // prepare inputs (we're sweeping from left to right in steps seqLength 
long)
       if (p + seqLength + 1 >= data.size() || n == 0) {
-        hPrev = Nd4j.zeros(hiddenSize, 1); // reset RNN memory
+        hPrev = Nd4j.zeros(hiddenLayerSize, 1); // reset RNN memory
         p = 0; // go from start of data
+        currentEpoch++;
+        if (currentEpoch == epochs) {
+          System.out.println("training finished: e:" + epochs + ", l: " + 
smoothLoss + ", h:(" + learningRate + ", " + seqLength + ", " + hiddenLayerSize 
+ ")");
+          break;
+        }
       }
 
-      INDArray inputs = Nd4j.create(seqLength);
-      int c = 0;
-      for (Character ch : data.subList(p, p + seqLength)) {
-        Integer ix = charToIx.get(ch);
-        inputs.putScalar(c, ix);
-        c++;
-      }
-      INDArray targets = Nd4j.create(seqLength);
-      c = 0;
-      for (Character ch : data.subList(p + 1, p + seqLength + 1)) {
-        Integer ix = charToIx.get(ch);
-        targets.putScalar(c, ix);
-        c++;
-      }
+      INDArray inputs = getSequence(p);
+      INDArray targets = getSequence(p + 1);
 
-      // sample from the model now and then
+      // sample from the model every now and then
       if (n % 1000 == 0) {
-        sample(vocabSize, ixToChar, wxh, whh, why, bh, by, hPrev, inputs);
+        String txt = sample(inputs.getInt(0));
+        System.out.printf("\n---\n %s \n----\n", txt);
       }
 
       INDArray dWxh = Nd4j.zerosLike(wxh);
@@ -125,7 +139,11 @@ public class RNN {
       // forward seqLength characters through the net and fetch gradient
       double loss = lossFun(vocabSize, wxh, whh, why, bh, by, hPrev, inputs, 
targets, dWxh, dWhh, dWhy, dbh, dby);
      smoothLoss = smoothLoss * 0.999 + loss * 0.001;
-      if (n % 100 == 0) {
+      if (Double.isNaN(smoothLoss)) {
+        System.out.println("loss is NaN (over/underflow occured, try adjusting 
hyperparameters)");
+        break;
+      }
+      if (n % 1000 == 0) {
         System.out.printf("iter %d, loss: %f\n", n, smoothLoss); // print 
progress
       }
 
@@ -150,6 +168,17 @@ public class RNN {
     }
   }
 
+  private INDArray getSequence(int p) {
+    INDArray inputs = Nd4j.create(seqLength);
+    int c = 0;
+    for (Character ch : data.subList(p, p + seqLength)) {
+      Integer ix = charToIx.get(ch);
+      inputs.putScalar(c, ix);
+      c++;
+    }
+    return inputs;
+  }
+
   /**
    * inputs, targets are both list of integers
    * hprev is Hx1 array of initial hidden state
@@ -233,13 +262,11 @@ public class RNN {
   }
 
   /**
-   * sample a sequence of integers from the model, h is memory state, seed_ix 
is seed letter for first time step
+   * sample a sequence of integers from the model, using current (hPrev) 
memory state, seedIx is seed letter for first time step
    */
-  private void sample(int vocabSize, Map<Integer, Character> ixToChar, 
INDArray wxh, INDArray whh, INDArray why,
-                      INDArray bh, INDArray by, INDArray hPrev, INDArray 
inputs) {
+  public String sample(int seedIx) {
 
     INDArray x = Nd4j.zeros(vocabSize, 1);
-    int seedIx = inputs.getInt(0);
     x.putScalar(seedIx, 1);
     int sampleSize = 200;
     INDArray ixes = Nd4j.create(sampleSize);
@@ -263,15 +290,18 @@ public class RNN {
       ixes.putScalar(t, ix);
     }
 
-    String txt = "";
-
+    StringBuilder txt = new StringBuilder();
 
     NdIndexIterator ndIndexIterator = new NdIndexIterator(ixes.shape());
     while (ndIndexIterator.hasNext()) {
       int[] next = ndIndexIterator.next();
-      txt += ixToChar.get(ixes.getInt(next));
+      txt.append(ixToChar.get(ixes.getInt(next)));
     }
-    System.out.printf("\n---\n %s \n----\n", txt);
+    return txt.toString();
+  }
+
+  public int getVocabSize() {
+    return vocabSize;
   }
 
 }

Added: labs/yay/trunk/core/src/main/java/org/apache/yay/WordRNN.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/WordRNN.java?rev=1763584&view=auto
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/WordRNN.java (added)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/WordRNN.java Thu Oct  6 12:58:22 2016
@@ -0,0 +1,309 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *  http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.yay;
+
+import org.apache.commons.math3.distribution.EnumeratedDistribution;
+import org.apache.commons.math3.util.Pair;
+import org.nd4j.linalg.api.iter.NdIndexIterator;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.impl.transforms.SetRange;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.ops.transforms.Transforms;
+
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+/**
+ * A min word-level vanilla RNN model, based on Andrej Karpathy's python code.
+ * See also:
+ *
+ * @see <a href="http://karpathy.github.io/2015/05/21/rnn-effectiveness">The Unreasonable Effectiveness of Recurrent Neural Networks</a>
+ * @see <a href="https://gist.github.com/karpathy/d4dee566867f8291f086">Minimal character-level language model with a Vanilla Recurrent Neural Network, in Python/numpy</a>
+ */
+public class WordRNN {
+
+  // hyperparameters
+  private final float learningRate; // learning rate used in the Adagrad updates
+  private final int seqLength; // no. of steps to unroll the RNN for
+  private final int hiddenLayerSize; // size of hidden layer of neurons
+  private final int epochs;
+  private final int vocabSize;
+  private final Map<String, Integer> stringToIx;
+  private final Map<Integer, String> ixToString;
+  private final List<String> data;
+
+  // model parameters
+  private final INDArray wxh; // input to hidden
+  private final INDArray whh; // hidden to hidden
+  private final INDArray why; // hidden to output
+  private final INDArray bh; // hidden bias
+  private final INDArray by; // output bias
+
+  private INDArray hPrev = null; // memory state
+
+  public WordRNN(float learningRate, int seqLength, int hiddenLayerSize, int 
epochs, String text) {
+    this.learningRate = learningRate;
+    this.seqLength = seqLength;
+    this.hiddenLayerSize = hiddenLayerSize;
+    this.epochs = epochs;
+    String[] textStrings = text.split(" ");
+    data = new LinkedList<>();
+    Collections.addAll(data, textStrings);
+    Set<String> strings = new HashSet<>(data);
+    vocabSize = strings.size();
+    System.out.printf("data has %d words, %d unique.", data.size(), vocabSize);
+    stringToIx = new HashMap<>();
+    ixToString = new HashMap<>();
+    int i = 0;
+    for (String s : strings) {
+      stringToIx.put(s, i);
+      ixToString.put(i, s);
+      i++;
+    }
+
+    wxh = Nd4j.randn(hiddenLayerSize, vocabSize).mul(0.01);
+    whh = Nd4j.randn(hiddenLayerSize, hiddenLayerSize).mul(0.01);
+    why = Nd4j.randn(vocabSize, hiddenLayerSize).mul(0.01);
+    bh = Nd4j.zeros(hiddenLayerSize, 1).mul(0.01);
+    by = Nd4j.zeros(vocabSize, 1).mul(0.01);
+  }
+
+  public void learn() {
+
+    int currentEpoch = 0;
+
+    int n = 0;
+    int p = 0;
+
+    // memory variables for Adagrad
+    INDArray mWxh = Nd4j.zerosLike(wxh);
+    INDArray mWhh = Nd4j.zerosLike(whh);
+    INDArray mWhy = Nd4j.zerosLike(why);
+
+    INDArray mbh = Nd4j.zerosLike(bh);
+    INDArray mby = Nd4j.zerosLike(by);
+
+    // loss at iteration 0
+    double smoothLoss = -Math.log(1.0 / vocabSize) * seqLength;
+
+    while (true) {
+      // prepare inputs (we're sweeping from left to right in steps seqLength 
long)
+      if (p + seqLength + 1 >= data.size() || n == 0) {
+        hPrev = Nd4j.zeros(hiddenLayerSize, 1); // reset RNN memory
+        p = 0; // go from start of data
+        currentEpoch++;
+        if (currentEpoch == epochs) {
+          System.out.println("training finished: e:" + epochs + ", l: " + 
smoothLoss + ", h:(" + learningRate + ", " + seqLength + ", " + hiddenLayerSize 
+ ")");
+          break;
+        }
+      }
+
+      INDArray inputs = getSequence(p);
+      INDArray targets = getSequence(p + 1);
+
+      // sample from the model every now and then
+      if (n % 1000 == 0) {
+        String txt = sample(inputs.getInt(0));
+        System.out.printf("\n---\n %s \n----\n", txt);
+      }
+
+      INDArray dWxh = Nd4j.zerosLike(wxh);
+      INDArray dWhh = Nd4j.zerosLike(whh);
+      INDArray dWhy = Nd4j.zerosLike(why);
+
+      INDArray dbh = Nd4j.zerosLike(bh);
+      INDArray dby = Nd4j.zerosLike(by);
+
+      // forward seqLength characters through the net and fetch gradient
+      double loss = lossFun(vocabSize, wxh, whh, why, bh, by, hPrev, inputs, 
targets, dWxh, dWhh, dWhy, dbh, dby);
+      smoothLoss = smoothLoss * 0.999 + loss * 0.001;
+      if (Double.isNaN(smoothLoss)) {
+        System.out.println("loss is NaN (over/underflow occured, try adjusting 
hyperparameters)");
+        break;
+      }
+      if (n % 1000 == 0) {
+        System.out.printf("iter %d, loss: %f\n", n, smoothLoss); // print 
progress
+      }
+
+      // perform parameter update with Adagrad
+      mWxh.addi(dWxh.mul(dWxh));
+      wxh.subi((dWxh.mul(learningRate)).div(Transforms.sqrt(mWxh.addi(1e-8))));
+
+      mWhh.addi(dWhh.mul(dWhh));
+      whh.subi(dWhh.mul(learningRate).div(Transforms.sqrt(mWhh.addi(1e-8))));
+
+      mWhy.addi(dWhy.mul(dWhy));
+      why.subi(dWhy.mul(learningRate).div(Transforms.sqrt(mWhy.addi(1e-8))));
+
+      mbh.addi(dbh.mul(dbh));
+      bh.subi(dbh.mul(learningRate).div(Transforms.sqrt(mbh.addi(1e-8))));
+
+      mby.addi(dby.mul(dby));
+      by.subi(dby.mul(learningRate).div(Transforms.sqrt(mby.addi(1e-8))));
+
+      p += seqLength; // move data pointer
+      n++; // iteration counter
+    }
+  }
+
+  private INDArray getSequence(int p) {
+    INDArray inputs = Nd4j.create(seqLength);
+    int c = 0;
+    for (String s : data.subList(p, p + seqLength)) {
+      Integer ix = stringToIx.get(s);
+      inputs.putScalar(c, ix);
+      c++;
+    }
+    return inputs;
+  }
+
+  /**
+   * inputs, targets are both list of integers
+   * hprev is Hx1 array of initial hidden state
+   * returns the loss, gradients on model parameters and last hidden state
+   */
+  private double lossFun(int vocabSize, INDArray wxh, INDArray whh, INDArray 
why, INDArray bh, INDArray by, INDArray hPrev,
+                         INDArray inputs, INDArray targets, INDArray dWxh, 
INDArray dWhh, INDArray dWhy, INDArray dbh,
+                         INDArray dby) {
+
+    INDArray xs = Nd4j.zeros(inputs.length(), vocabSize);
+    INDArray hs = null;
+    INDArray ys = null;
+    INDArray ps = null;
+
+    INDArray hs1 = Nd4j.create(hPrev.shape());
+    Nd4j.copy(hPrev, hs1);
+
+    double loss = 0;
+
+    // forward pass
+    for (int t = 0; t < inputs.length(); t++) {
+      int tIndex = inputs.getScalar(t).getInt(0);
+      xs.putScalar(t, tIndex, 1); // encode in 1-of-k representation
+      INDArray hsRow = t == 0 ? hs1 : hs.getRow(t - 1);
+      INDArray hst = 
Transforms.tanh((wxh.mmul(xs.getRow(t).transpose())).add((whh.mmul(hsRow)).add(bh)));
 // hidden state
+      if (hs == null) {
+        hs = init(inputs.length(), hst);
+      }
+      hs.putRow(t, hst);
+
+      INDArray yst = (why.mmul(hst)).add(by); // unnormalized log 
probabilities for next chars
+      if (ys == null) {
+        ys = init(inputs.length(), yst);
+      }
+      ys.putRow(t, yst);
+      INDArray exp = Transforms.exp(yst);
+      Number sumExp = exp.sumNumber();
+      INDArray pst = exp.div(sumExp); // probabilities for next chars
+      if (ps == null) {
+        ps = init(inputs.length(), pst);
+      }
+      ps.putRow(t, pst);
+      loss += -Transforms.log(ps.getRow(t).getRow(targets.getInt(t)), 
true).sumNumber().doubleValue(); // softmax (cross-entropy loss)
+    }
+
+    // backward pass: compute gradients going backwards
+    INDArray dhNext = Nd4j.zerosLike(hs.getRow(0));
+    for (int t = inputs.length() - 1; t >= 0; t--) {
+      INDArray dy = ps.getRow(t).dup();
+      dy.putRow(targets.getInt(t), dy.getRow(targets.getInt(t)).sub(1)); // 
backprop into y
+      INDArray hst = hs.getRow(t);
+      dWhy.addi(dy.mmul(hst.transpose()));
+      dby.addi(dy);
+      INDArray dh = why.transpose().mmul(dy).add(dhNext); // backprop into h
+      INDArray dhraw = (Nd4j.ones(hst.shape()).sub(hst).mul(hst)).mul(dh); // 
backprop through tanh nonlinearity
+      dbh.addi(dhraw);
+      dWxh.addi(dhraw.mmul(xs.getRow(t)));
+      INDArray hsRow = t == 0 ? hs1 : hs.getRow(t - 1);
+      dWhh.addi(dhraw.mmul(hsRow.transpose()));
+      dhNext = whh.transpose().mmul(dhraw);
+
+    }
+    // clip exploding gradients
+    Nd4j.getExecutioner().execAndReturn(new SetRange(dWxh, -5, 5));
+    Nd4j.getExecutioner().execAndReturn(new SetRange(dWhh, -5, 5));
+    Nd4j.getExecutioner().execAndReturn(new SetRange(dWhy, -5, 5));
+    Nd4j.getExecutioner().execAndReturn(new SetRange(dbh, -5, 5));
+    Nd4j.getExecutioner().execAndReturn(new SetRange(dby, -5, 5));
+
+    return loss;
+  }
+
+  private INDArray init(int t, INDArray ast) {
+    INDArray as;
+    int[] aShape = ast.shape();
+    int[] shape = new int[1 + aShape.length];
+    shape[0] = t;
+    System.arraycopy(aShape, 0, shape, 1, aShape.length);
+    as = Nd4j.create(shape);
+    return as;
+  }
+
+  /**
+   * sample a sequence of integers from the model, using current (hPrev) 
memory state, seedIx is seed letter for first time step
+   */
+  public String sample(int seedIx) {
+
+    INDArray x = Nd4j.zeros(vocabSize, 1);
+    x.putScalar(seedIx, 1);
+    int sampleSize = 200;
+    INDArray ixes = Nd4j.create(sampleSize);
+
+    for (int t = 0; t < sampleSize; t++) {
+      INDArray h = 
Transforms.tanh((wxh.mmul(x)).add((whh.mmul(hPrev)).add(bh)));
+      INDArray y = (why.mmul(h)).add(by);
+      INDArray exp = Transforms.exp(y);
+      INDArray pm = exp.div(Nd4j.sum(exp)).ravel();
+
+      List<Pair<Integer, Double>> d = new LinkedList<>();
+      for (int pi = 0; pi < vocabSize; pi++) {
+        d.add(new Pair<>(pi, pm.getDouble(0, pi)));
+      }
+      EnumeratedDistribution<Integer> distribution = new 
EnumeratedDistribution<>(d);
+
+      int ix = distribution.sample();
+
+      x = Nd4j.zeros(vocabSize, 1);
+      x.putScalar(ix, 1);
+      ixes.putScalar(t, ix);
+    }
+
+    StringBuilder txt = new StringBuilder();
+
+    NdIndexIterator ndIndexIterator = new NdIndexIterator(ixes.shape());
+    while (ndIndexIterator.hasNext()) {
+      int[] next = ndIndexIterator.next();
+      if (txt.length() > 0) {
+        txt.append(' ');
+      }
+      txt.append(ixToString.get(ixes.getInt(next)));
+    }
+    return txt.toString();
+  }
+
+  public int getVocabSize() {
+    return vocabSize;
+  }
+
+}

Propchange: labs/yay/trunk/core/src/main/java/org/apache/yay/WordRNN.java
------------------------------------------------------------------------------
    svn:eol-style = native

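Both CharRNN and WordRNN share the same training loop: lossFun() runs the vanilla-RNN forward pass over one seqLength window and accumulates gradients on the backward pass, and learn() then applies a per-parameter Adagrad step. As a reading aid (not part of the commit), the equations the code above implements are, in LaTeX notation:

    h_t = \tanh(W_{xh}\,x_t + W_{hh}\,h_{t-1} + b_h)
    y_t = W_{hy}\,h_t + b_y, \qquad p_t = \mathrm{softmax}(y_t)
    L   = -\sum_t \log p_t[\mathrm{target}_t]
    m_\theta \leftarrow m_\theta + g_\theta \odot g_\theta, \qquad
    \theta \leftarrow \theta - \eta\, g_\theta / \sqrt{m_\theta + 10^{-8}}

where x_t is the one-hot encoding of the t-th input index, \theta ranges over {wxh, whh, why, bh, by}, g_\theta is the corresponding gradient clipped to [-5, 5] via SetRange, and \eta is the learningRate passed to the constructor.
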
Copied: labs/yay/trunk/core/src/test/java/org/apache/yay/CharRNNCrossValidationTest.java (from r1763464, labs/yay/trunk/core/src/test/java/org/apache/yay/RNNTest.java)
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/CharRNNCrossValidationTest.java?p2=labs/yay/trunk/core/src/test/java/org/apache/yay/CharRNNCrossValidationTest.java&p1=labs/yay/trunk/core/src/test/java/org/apache/yay/RNNTest.java&r1=1763464&r2=1763584&rev=1763584&view=diff
==============================================================================
--- labs/yay/trunk/core/src/test/java/org/apache/yay/RNNTest.java (original)
+++ labs/yay/trunk/core/src/test/java/org/apache/yay/CharRNNCrossValidationTest.java Thu Oct  6 12:58:22 2016
@@ -20,20 +20,65 @@ package org.apache.yay;
 
 import org.apache.commons.io.IOUtils;
 import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
 
 import java.io.InputStream;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.List;
+import java.util.Random;
 
 /**
- * Tests for {@link RNN}
+ * CV tests for {@link CharRNN}
  */
-public class RNNTest {
+@RunWith(Parameterized.class)
+public class CharRNNCrossValidationTest {
+
+  private float learningRate;
+  private int seqLength;
+  private int hiddenLayerSize;
+  private Random r = new Random();
+
+  public CharRNNCrossValidationTest(float learningRate, int seqLength, int 
hiddenLayerSize) {
+    this.learningRate = learningRate;
+    this.seqLength = seqLength;
+    this.hiddenLayerSize = hiddenLayerSize;
+  }
+
+  @Parameterized.Parameters
+  public static Collection<Object[]> data() {
+    return Arrays.asList(new Object[][]{
+            {1e-1f, 100, 25}, {1e-1f, 200, 512}, {1e-1f, 25, 25}, {1e-1f, 250, 
512},
+            {1e-1f, 25, 100}, {1e-1f, 200, 50}, {1e-1f, 200, 40}, {1e-1f, 100, 
30}, {1e-1f, 100, 20}, {1e-1f, 250, 20}, {1e-1f, 250, 15},
+            {1e-2f, 50, 64}, {3e-2f, 50, 128}, {1e-2f, 100, 128}, {1e-2f, 100, 
256}, {1e-2f, 100, 512}, {1e-2f, 100, 128},
+            {1e-3f, 100, 256}, {1e-3f, 100, 512}, {1e-4f, 100, 128}, {1e-4f, 
100, 256},
+    });
+  }
 
   @Test
-  public void test() throws Exception {
+  public void testLearnWithDifferentHyperparameters() throws Exception {
+    System.out.println("hyperparameters: " + learningRate + ", " + seqLength + 
", " + hiddenLayerSize);
     InputStream resourceAsStream = 
getClass().getResourceAsStream("/word2vec/abstracts.txt");
     String text = IOUtils.toString(resourceAsStream);
-    RNN n = new RNN();
-    n.learn(text);
+    int epochs = 20;
+    CharRNN charRNN = new CharRNN(learningRate, seqLength, hiddenLayerSize, 
epochs, text);
+    List<String> words = Arrays.asList(text.split(" "));
+    charRNN.learn();
+    for (int i = 0; i < 10; i++) {
+      double c = 0;
+      String sample = charRNN.sample(r.nextInt(charRNN.getVocabSize()));
+      String[] sampleWords = sample.split(" ");
+      for (String sw : sampleWords) {
+        if (words.contains(sw)) {
+          c++;
+        }
+      }
+      if (c > 0) {
+        c /= sample.length();
+      }
+      System.out.println("correct word ratio: " + c);
+    }
   }
 
 }
\ No newline at end of file

Added: labs/yay/trunk/core/src/test/java/org/apache/yay/WordRNNCrossValidationTest.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/WordRNNCrossValidationTest.java?rev=1763584&view=auto
==============================================================================
--- labs/yay/trunk/core/src/test/java/org/apache/yay/WordRNNCrossValidationTest.java (added)
+++ labs/yay/trunk/core/src/test/java/org/apache/yay/WordRNNCrossValidationTest.java Thu Oct  6 12:58:22 2016
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *  http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.yay;
+
+import org.apache.commons.io.IOUtils;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+import java.io.InputStream;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.List;
+import java.util.Random;
+
+/**
+ * CV tests for {@link WordRNN}
+ */
+@RunWith(Parameterized.class)
+public class WordRNNCrossValidationTest {
+
+  private float learningRate;
+  private int seqLength;
+  private int hiddenLayerSize;
+  private Random r = new Random();
+
+  public WordRNNCrossValidationTest(float learningRate, int seqLength, int 
hiddenLayerSize) {
+    this.learningRate = learningRate;
+    this.seqLength = seqLength;
+    this.hiddenLayerSize = hiddenLayerSize;
+  }
+
+  @Parameterized.Parameters
+  public static Collection<Object[]> data() {
+    return Arrays.asList(new Object[][]{
+            {1e-1f, 100, 25}, {1e-1f, 200, 512}, {1e-1f, 25, 25}, {1e-1f, 250, 
512},
+            {1e-1f, 25, 100}, {1e-1f, 200, 50}, {1e-1f, 200, 40}, {1e-1f, 100, 
30}, {1e-1f, 100, 20}, {1e-1f, 250, 20}, {1e-1f, 250, 15},
+            {1e-2f, 50, 64}, {3e-2f, 50, 128}, {1e-2f, 100, 128}, {1e-2f, 100, 
256}, {1e-2f, 100, 512}, {1e-2f, 100, 128},
+            {1e-3f, 100, 256}, {1e-3f, 100, 512}, {1e-4f, 100, 128}, {1e-4f, 
100, 256},
+            {1e-4f, 200, 1000},
+    });
+  }
+
+  @Test
+  public void testLearnWithDifferentHyperparameters() throws Exception {
+    System.out.println("hyperparameters: " + learningRate + ", " + seqLength + 
", " + hiddenLayerSize);
+    InputStream resourceAsStream = 
getClass().getResourceAsStream("/word2vec/abstracts.txt");
+    String text = IOUtils.toString(resourceAsStream);
+    int epochs = 100;
+    WordRNN wordRNN = new WordRNN(learningRate, seqLength, hiddenLayerSize, 
epochs, text);
+    List<String> words = Arrays.asList(text.split(" "));
+    wordRNN.learn();
+    for (int i = 0; i < 10; i++) {
+      double c = 0;
+      String sample = wordRNN.sample(r.nextInt(wordRNN.getVocabSize()));
+      String[] sampleWords = sample.split(" ");
+      for (String sw : sampleWords) {
+        if (words.contains(sw)) {
+          c++;
+        }
+      }
+      if (c > 0) {
+        c /= sample.length();
+      }
+      System.out.println("correct word ratio: " + c);
+    }
+  }
+
+}
\ No newline at end of file

Propchange: labs/yay/trunk/core/src/test/java/org/apache/yay/WordRNNCrossValidationTest.java
------------------------------------------------------------------------------
    svn:eol-style = native


