Author: tommaso
Date: Fri Oct 14 07:18:02 2016
New Revision: 1764824

URL: http://svn.apache.org/viewvc?rev=1764824&view=rev
Log:
added apprunner plugin, fixed bug in sRNN, adjusted test

Added:
    labs/yay/trunk/core/src/main/java/org/apache/yay/NNRunner.java   (with props)
Modified:
    labs/yay/trunk/core/pom.xml
    labs/yay/trunk/core/src/main/java/org/apache/yay/RNN.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/StackedRNN.java
    labs/yay/trunk/core/src/test/java/org/apache/yay/RNNCrossValidationTest.java

Modified: labs/yay/trunk/core/pom.xml
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/pom.xml?rev=1764824&r1=1764823&r2=1764824&view=diff
==============================================================================
--- labs/yay/trunk/core/pom.xml (original)
+++ labs/yay/trunk/core/pom.xml Fri Oct 14 07:18:02 2016
@@ -72,6 +72,27 @@
                     <argLine>-Xmx8g</argLine>
                 </configuration>
             </plugin>
+            <plugin>
+                <groupId>org.codehaus.mojo</groupId>
+                <artifactId>appassembler-maven-plugin</artifactId>
+                <version>1.1.1</version>
+                <executions>
+                    <execution>
+                        <phase>package</phase>
+                        <goals>
+                            <goal>assemble</goal>
+                        </goals>
+                        <configuration>
+                            <programs>
+                                <program>
+                                    <mainClass>org.apache.yay.NNRunner</mainClass>
+                                    <name>nn</name>
+                                </program>
+                            </programs>
+                        </configuration>
+                    </execution>
+                </executions>
+            </plugin>
         </plugins>
     </build>
 </project>
\ No newline at end of file

Added: labs/yay/trunk/core/src/main/java/org/apache/yay/NNRunner.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/NNRunner.java?rev=1764824&view=auto
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/NNRunner.java (added)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/NNRunner.java Fri Oct 14 07:18:02 2016
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *  http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.yay;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.Random;
+
+/**
+ * Runner class for different types of neural network
+ */
+public class NNRunner {
+
+  public static void main(String[] args) {
+    if (args.length > 0) {
+      Random random = new Random();
+      String nnType = args[0];
+      switch (nnType) {
+        case "skipgram": {
+          // skipgram network
+          break;
+        }
+        case "recurrent": {
+          // recurrent neural network
+          // e.g. bin/nn recurrent core/src/test/resources/word2vec/sentences.txt true 100 25 100 stacked
+          float learningRate = 1e-2f;
+          int seqLength = 25;
+          int hiddenLayerSize = 30;
+          int epochs = 20;
+          boolean useChars = true;
+          String text = "";
+
+          if (args.length > 1 && args[1] != null) {
+            Path path = Paths.get(args[1]);
+            try {
+              byte[] bytes = Files.readAllBytes(path);
+              text = new String(bytes);
+            } catch (IOException e) {
+              throw new RuntimeException("could not read from path " + path);
+            }
+          }
+          if (args.length > 2 && args[2] != null) {
+            useChars = Boolean.valueOf(args[2]);
+          }
+          if (args.length > 3 && args[3] != null) {
+            epochs = Integer.valueOf(args[3]);
+          }
+          if (args.length > 4 && args[4] != null) {
+            hiddenLayerSize = Integer.valueOf(args[4]);
+          }
+          if (args.length > 5 && args[5] != null) {
+            seqLength = Integer.valueOf(args[5]);
+          }
+          RNN rnn = new RNN(learningRate, seqLength, hiddenLayerSize, epochs, text, useChars);
+          if (args.length > 6 && args[6] != null) {
+            if ("stacked".equals(args[6])) {
+              rnn = new StackedRNN(learningRate, seqLength, hiddenLayerSize, epochs, text, useChars);
+            }
+          }
+
+          rnn.learn();
+          int seed = random.nextInt(rnn.vocabSize);
+          System.out.println(rnn.sample(seed));
+          break;
+        }
+        case "multi": {
+          // multi layer network
+          break;
+        }
+      }
+    }
+  }
+}

Propchange: labs/yay/trunk/core/src/main/java/org/apache/yay/NNRunner.java
------------------------------------------------------------------------------
    svn:eol-style = native

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/RNN.java
URL: 
http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/RNN.java?rev=1764824&r1=1764823&r2=1764824&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/RNN.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/RNN.java Fri Oct 14 
07:18:02 2016
@@ -26,6 +26,11 @@ import org.nd4j.linalg.api.ops.impl.tran
 import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.linalg.ops.transforms.Transforms;
 
+import java.io.BufferedWriter;
+import java.io.File;
+import java.io.FileWriter;
+import java.io.IOException;
+import java.util.Date;
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.LinkedList;
@@ -81,7 +86,7 @@ public class RNN {
     Set<String> tokens = new HashSet<>(data);
     vocabSize = tokens.size();
 
-    System.out.printf("data has %d tokens, %d unique.", data.size(), 
vocabSize);
+    System.out.printf("data has %d tokens, %d unique.\n", data.size(), 
vocabSize);
     charToIx = new HashMap<>();
     ixToChar = new HashMap<>();
     int i = 0;
@@ -140,7 +145,7 @@ public class RNN {
       INDArray targets = getSequence(p + 1);
 
       // sample from the model every now and then
-      if (n % 1000 == 0) {
+      if (n % 1000 == 0 && n > 0) {
         String txt = sample(inputs.getInt(0));
         System.out.printf("\n---\n %s \n----\n", txt);
       }
@@ -330,7 +335,7 @@ public class RNN {
 
   @Override
   public String toString() {
-    return "RNN{" +
+    return getClass().getName() + "{" +
             "learningRate=" + learningRate +
             ", seqLength=" + seqLength +
             ", hiddenLayerSize=" + hiddenLayerSize +
@@ -342,7 +347,7 @@ public class RNN {
 
 
   public String getHyperparamsString() {
-    return "RNN{" +
+    return getClass().getName() + "{" +
             "wxh=" + wxh +
             ", whh=" + whh +
             ", why=" + why +
@@ -350,4 +355,20 @@ public class RNN {
             ", by=" + by +
             '}';
   }
+
+  public void serialize(String prefix) throws IOException {
+    BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(new File(prefix + new Date().toString() + ".csv")));
+    bufferedWriter.write("wxh");
+    bufferedWriter.write(wxh.toString());
+    bufferedWriter.write("whh");
+    bufferedWriter.write(whh.toString());
+    bufferedWriter.write("why");
+    bufferedWriter.write(why.toString());
+    bufferedWriter.write("bh");
+    bufferedWriter.write(bh.toString());
+    bufferedWriter.write("by");
+    bufferedWriter.write(by.toString());
+    bufferedWriter.flush();
+    bufferedWriter.close();
+  }
 }

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/StackedRNN.java
URL: 
http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/StackedRNN.java?rev=1764824&r1=1764823&r2=1764824&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/StackedRNN.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/StackedRNN.java Fri Oct 14 
07:18:02 2016
@@ -25,6 +25,11 @@ import org.nd4j.linalg.api.ops.impl.tran
 import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.linalg.ops.transforms.Transforms;
 
+import java.io.BufferedWriter;
+import java.io.File;
+import java.io.FileWriter;
+import java.io.IOException;
+import java.util.Date;
 import java.util.LinkedList;
 import java.util.List;
 
@@ -104,7 +109,7 @@ public class StackedRNN extends RNN {
       INDArray targets = getSequence(p + 1);
 
       // sample from the model every now and then
-      if (n % 1000 == 0) {
+      if (n % 1000 == 0 && n > 0) {
         String txt = sample(inputs.getInt(0));
         System.out.printf("\n---\n %s \n----\n", txt);
       }
@@ -195,9 +200,9 @@ public class StackedRNN extends RNN {
       if (hs2 == null) {
         hs2 = init(inputs.length(), hst2);
       }
-      hs.putRow(t, hst);
+      hs2.putRow(t, hst2);
 
-      INDArray yst = (wh2y.mmul(hst)).add(by); // unnormalized log 
probabilities for next chars
+      INDArray yst = (wh2y.mmul(hst2)).add(by); // unnormalized log 
probabilities for next chars
       if (ys == null) {
         ys = init(inputs.length(), yst);
       }
@@ -226,6 +231,7 @@ public class StackedRNN extends RNN {
       dby.addi(dy);
       INDArray dh2 = wh2y.transpose().mmul(dy).add(dh2Next); // backprop into 
h2
       INDArray dhraw2 = 
(Nd4j.ones(hst2.shape()).sub(hst2).mul(hst2)).mul(dh2); // backprop through 
tanh nonlinearity
+//      INDArray dhraw2 = Nd4j.getExecutioner().execAndReturn(new 
SetRange(hst2, 0, Double.MAX_VALUE));; // backprop through relu nonlinearity
       dbh2.addi(dhraw2);
 
       INDArray hst = hs.getRow(t);
@@ -233,6 +239,7 @@ public class StackedRNN extends RNN {
       dbh.addi(dh2);
       INDArray dh = whh2.transpose().mmul(dh2).add(dhNext); // backprop into h
       INDArray dhraw = (Nd4j.ones(hst.shape()).sub(hst).mul(hst)).mul(dh); // 
backprop through tanh nonlinearity
+//      INDArray dhraw = Nd4j.getExecutioner().execAndReturn(new SetRange(hst, 
0, Double.MAX_VALUE));; // backprop through relu nonlinearity
       dbh.addi(dhraw);
 
       dWxh.addi(dhraw.mmul(xs.getRow(t)));
@@ -255,6 +262,7 @@ public class StackedRNN extends RNN {
   /**
    * sample a sequence of integers from the model, using current (hPrev) 
memory state, seedIx is seed letter for first time step
    */
+  @Override
   public String sample(int seedIx) {
 
     INDArray x = Nd4j.zeros(vocabSize, 1);
@@ -285,4 +293,25 @@ public class StackedRNN extends RNN {
     return getSampleString(ixes);
   }
 
+  @Override
+  public void serialize(String prefix) throws IOException {
+    BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(new File(prefix + new Date().toString() + ".csv")));
+    bufferedWriter.write("wxh");
+    bufferedWriter.write(wxh.toString());
+    bufferedWriter.write("whh");
+    bufferedWriter.write(whh.toString());
+    bufferedWriter.write("whh2");
+    bufferedWriter.write(whh2.toString());
+    bufferedWriter.write("wh2y");
+    bufferedWriter.write(wh2y.toString());
+    bufferedWriter.write("bh");
+    bufferedWriter.write(bh.toString());
+    bufferedWriter.write("bh2");
+    bufferedWriter.write(bh2.toString());
+    bufferedWriter.write("by");
+    bufferedWriter.write(by.toString());
+    bufferedWriter.flush();
+    bufferedWriter.close();
+  }
+
 }

Modified: labs/yay/trunk/core/src/test/java/org/apache/yay/RNNCrossValidationTest.java
URL: 
http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/RNNCrossValidationTest.java?rev=1764824&r1=1764823&r2=1764824&view=diff
==============================================================================
--- labs/yay/trunk/core/src/test/java/org/apache/yay/RNNCrossValidationTest.java (original)
+++ labs/yay/trunk/core/src/test/java/org/apache/yay/RNNCrossValidationTest.java Fri Oct 14 07:18:02 2016
@@ -19,6 +19,7 @@
 package org.apache.yay;
 
 import org.apache.commons.io.IOUtils;
+import org.junit.Before;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.Parameterized;
@@ -39,6 +40,9 @@ public class RNNCrossValidationTest {
   private int seqLength;
   private int hiddenLayerSize;
   private Random r = new Random();
+  private String text;
+  private final int epochs = 2;
+  private List<String> words;
 
   public RNNCrossValidationTest(float learningRate, int seqLength, int 
hiddenLayerSize) {
     this.learningRate = learningRate;
@@ -46,71 +50,73 @@ public class RNNCrossValidationTest {
     this.hiddenLayerSize = hiddenLayerSize;
   }
 
+  @Before
+  public void setUp() throws Exception {
+    InputStream stream = 
getClass().getResourceAsStream("/word2vec/abstracts.txt");
+    text = IOUtils.toString(stream);
+    words = Arrays.asList(text.split(" "));
+    stream.close();
+  }
+
   @Parameterized.Parameters
   public static Collection<Object[]> data() {
     return Arrays.asList(new Object[][]{
-            {1e-1f, 25, 100},
-            {3e-1f, 25, 100},
-            {3e-1f, 100, 25},
-            {1e-1f, 100, 25},
+            {3e-1f, 50, 5},
+            {3e-1f, 50, 10},
+            {3e-1f, 50, 15},
+            {3e-1f, 50, 25},
+            {3e-1f, 50, 50},
+            {3e-1f, 50, 100},
     });
   }
 
   @Test
   public void testVanillaCharRNNLearn() throws Exception {
-    InputStream resourceAsStream = 
getClass().getResourceAsStream("/word2vec/sentences.txt");
-    String text = IOUtils.toString(resourceAsStream);
-    int epochs = 10;
-    RNN RNN = new RNN(learningRate, seqLength, hiddenLayerSize, epochs, text);
-    checkCorrectWordsRatio(text, RNN);
+    RNN rnn = new RNN(learningRate, seqLength, hiddenLayerSize, epochs, text);
+    evaluate(rnn, true);
+    rnn.serialize("target/crnn-weights-");
   }
 
   @Test
   public void testVanillaWordRNNLearn() throws Exception {
-    InputStream resourceAsStream = 
getClass().getResourceAsStream("/word2vec/sentences.txt");
-    String text = IOUtils.toString(resourceAsStream);
-    int epochs = 100;
-    RNN RNN = new RNN(learningRate, seqLength, hiddenLayerSize, epochs, text, 
false);
-    checkCorrectWordsRatio(text, RNN);
+    RNN rnn = new RNN(learningRate, seqLength, hiddenLayerSize, epochs, text, 
false);
+    evaluate(rnn, false);
+    rnn.serialize("target/wrnn-weights-");
   }
 
   @Test
   public void testStackedCharRNNLearn() throws Exception {
-    InputStream resourceAsStream = 
getClass().getResourceAsStream("/word2vec/sentences.txt");
-    String text = IOUtils.toString(resourceAsStream);
-    int epochs = 100;
-    RNN RNN = new StackedRNN(learningRate, seqLength, hiddenLayerSize, epochs, 
text);
-    checkCorrectWordsRatio(text, RNN);
+    RNN rnn = new StackedRNN(learningRate, seqLength, hiddenLayerSize, epochs, 
text);
+    evaluate(rnn, true);
+    rnn.serialize("target/scrnn-weights-");
   }
 
   @Test
   public void testStackedWordRNNLearn() throws Exception {
-    InputStream resourceAsStream = 
getClass().getResourceAsStream("/word2vec/sentences.txt");
-    String text = IOUtils.toString(resourceAsStream);
-    int epochs = 100;
-    RNN RNN = new StackedRNN(learningRate, seqLength, hiddenLayerSize, epochs, 
text, false);
-    checkCorrectWordsRatio(text, RNN);
+    RNN rnn = new StackedRNN(learningRate, seqLength, hiddenLayerSize, epochs, 
text, false);
+    evaluate(rnn, false);
+    rnn.serialize("target/swrnn-weights-");
   }
 
-  private void checkCorrectWordsRatio(String text, RNN RNN) {
-    System.out.println(RNN);
-    List<String> words = Arrays.asList(text.split(" "));
-    RNN.learn();
+  private void evaluate(RNN rnn, boolean checkRatio) {
+    System.out.println(rnn);
+    rnn.learn();
+    double c = 0;
     for (int i = 0; i < 10; i++) {
-      double c = 0;
-      String sample = RNN.sample(r.nextInt(RNN.getVocabSize()));
-      String[] sampleWords = sample.split(" ");
-      for (String sw : sampleWords) {
-        if (words.contains(sw)) {
-          c++;
+      String sample = rnn.sample(r.nextInt(rnn.getVocabSize()));
+      if (checkRatio) {
+        String[] sampleWords = sample.split(" ");
+        for (String sw : sampleWords) {
+          if (words.contains(sw)) {
+            c++;
+          }
+        }
+        if (c > 0) {
+          c /= sampleWords.length;
         }
       }
-      if (c > 0) {
-        c /= sample.length();
-      }
-      System.out.println("correct word ratio: " + c);
-
     }
+    System.out.println("average correct word ratio: " + (c / 10d));
   }
 
 }
\ No newline at end of file



---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@labs.apache.org
For additional commands, e-mail: commits-h...@labs.apache.org

Reply via email to