Author: tommaso Date: Fri Oct 14 07:18:02 2016 New Revision: 1764824 URL: http://svn.apache.org/viewvc?rev=1764824&view=rev Log: added apprunner plugin, fixed bug in sRNN, adjusted test
Added: labs/yay/trunk/core/src/main/java/org/apache/yay/NNRunner.java (with props) Modified: labs/yay/trunk/core/pom.xml labs/yay/trunk/core/src/main/java/org/apache/yay/RNN.java labs/yay/trunk/core/src/main/java/org/apache/yay/StackedRNN.java labs/yay/trunk/core/src/test/java/org/apache/yay/RNNCrossValidationTest.java Modified: labs/yay/trunk/core/pom.xml URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/pom.xml?rev=1764824&r1=1764823&r2=1764824&view=diff ============================================================================== --- labs/yay/trunk/core/pom.xml (original) +++ labs/yay/trunk/core/pom.xml Fri Oct 14 07:18:02 2016 @@ -72,6 +72,27 @@ <argLine>-Xmx8g</argLine> </configuration> </plugin> + <plugin> + <groupId>org.codehaus.mojo</groupId> + <artifactId>appassembler-maven-plugin</artifactId> + <version>1.1.1</version> + <executions> + <execution> + <phase>package</phase> + <goals> + <goal>assemble</goal> + </goals> + <configuration> + <programs> + <program> + <mainClass>org.apache.yay.NNRunner</mainClass> + <name>nn</name> + </program> + </programs> + </configuration> + </execution> + </executions> + </plugin> </plugins> </build> </project> \ No newline at end of file Added: labs/yay/trunk/core/src/main/java/org/apache/yay/NNRunner.java URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/NNRunner.java?rev=1764824&view=auto ============================================================================== --- labs/yay/trunk/core/src/main/java/org/apache/yay/NNRunner.java (added) +++ labs/yay/trunk/core/src/main/java/org/apache/yay/NNRunner.java Fri Oct 14 07:18:02 2016 @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.yay; + +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.Random; + +/** + * Runner class for different types of neural network + */ +public class NNRunner { + + public static void main(String[] args) { + if (args.length > 0) { + Random random = new Random(); + String nnType = args[0]; + switch (nnType) { + case "skipgram": { + // skipgram network + break; + } + case "recurrent": { + // recurrent neural network + // e.g. 
bin/nn recurrent core/src/test/resources/word2vec/sentences.txt true 100 25 100 stacked + float learningRate = 1e-2f; + int seqLength = 25; + int hiddenLayerSize = 30; + int epochs = 20; + boolean useChars = true; + String text = ""; + + if (args.length > 1 && args[1] != null) { + Path path = Paths.get(args[1]); + try { + byte[] bytes = Files.readAllBytes(path); + text = new String(bytes); + } catch (IOException e) { + throw new RuntimeException("could not read from path " + path); + } + } + if (args.length > 2 && args[2] != null) { + useChars = Boolean.valueOf(args[2]); + } + if (args.length > 3 && args[3] != null) { + epochs = Integer.valueOf(args[3]); + } + if (args.length > 4 && args[4] != null) { + hiddenLayerSize = Integer.valueOf(args[4]); + } + if (args.length > 5 && args[5] != null) { + seqLength = Integer.valueOf(args[5]); + } + RNN rnn = new RNN(learningRate, seqLength, hiddenLayerSize, epochs, text, useChars); + if (args.length > 6 && args[6] != null) { + if ("stacked".equals(args[6])) { + rnn = new StackedRNN(learningRate, seqLength, hiddenLayerSize, epochs, text, useChars); + } + } + + rnn.learn(); + int seed = random.nextInt(rnn.vocabSize); + System.out.println(rnn.sample(seed)); + break; + } + case "multi": { + // multi layer network + break; + } + } + } + } +} Propchange: labs/yay/trunk/core/src/main/java/org/apache/yay/NNRunner.java ------------------------------------------------------------------------------ svn:eol-style = native Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/RNN.java URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/RNN.java?rev=1764824&r1=1764823&r2=1764824&view=diff ============================================================================== --- labs/yay/trunk/core/src/main/java/org/apache/yay/RNN.java (original) +++ labs/yay/trunk/core/src/main/java/org/apache/yay/RNN.java Fri Oct 14 07:18:02 2016 @@ -26,6 +26,11 @@ import org.nd4j.linalg.api.ops.impl.tran import 
org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.ops.transforms.Transforms; +import java.io.BufferedWriter; +import java.io.File; +import java.io.FileWriter; +import java.io.IOException; +import java.util.Date; import java.util.HashMap; import java.util.HashSet; import java.util.LinkedList; @@ -81,7 +86,7 @@ public class RNN { Set<String> tokens = new HashSet<>(data); vocabSize = tokens.size(); - System.out.printf("data has %d tokens, %d unique.", data.size(), vocabSize); + System.out.printf("data has %d tokens, %d unique.\n", data.size(), vocabSize); charToIx = new HashMap<>(); ixToChar = new HashMap<>(); int i = 0; @@ -140,7 +145,7 @@ public class RNN { INDArray targets = getSequence(p + 1); // sample from the model every now and then - if (n % 1000 == 0) { + if (n % 1000 == 0 && n > 0) { String txt = sample(inputs.getInt(0)); System.out.printf("\n---\n %s \n----\n", txt); } @@ -330,7 +335,7 @@ public class RNN { @Override public String toString() { - return "RNN{" + + return getClass().getName() + "{" + "learningRate=" + learningRate + ", seqLength=" + seqLength + ", hiddenLayerSize=" + hiddenLayerSize + @@ -342,7 +347,7 @@ public class RNN { public String getHyperparamsString() { - return "RNN{" + + return getClass().getName() + "{" + "wxh=" + wxh + ", whh=" + whh + ", why=" + why + @@ -350,4 +355,20 @@ public class RNN { ", by=" + by + '}'; } + + public void serialize(String prefix) throws IOException { + BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(new File(prefix + new Date().toString() + ".csv"))); + bufferedWriter.write("wxh"); + bufferedWriter.write(wxh.toString()); + bufferedWriter.write("whh"); + bufferedWriter.write(whh.toString()); + bufferedWriter.write("why"); + bufferedWriter.write(why.toString()); + bufferedWriter.write("bh"); + bufferedWriter.write(bh.toString()); + bufferedWriter.write("by"); + bufferedWriter.write(by.toString()); + bufferedWriter.flush(); + bufferedWriter.close(); + } } Modified: 
labs/yay/trunk/core/src/main/java/org/apache/yay/StackedRNN.java URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/StackedRNN.java?rev=1764824&r1=1764823&r2=1764824&view=diff ============================================================================== --- labs/yay/trunk/core/src/main/java/org/apache/yay/StackedRNN.java (original) +++ labs/yay/trunk/core/src/main/java/org/apache/yay/StackedRNN.java Fri Oct 14 07:18:02 2016 @@ -25,6 +25,11 @@ import org.nd4j.linalg.api.ops.impl.tran import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.ops.transforms.Transforms; +import java.io.BufferedWriter; +import java.io.File; +import java.io.FileWriter; +import java.io.IOException; +import java.util.Date; import java.util.LinkedList; import java.util.List; @@ -104,7 +109,7 @@ public class StackedRNN extends RNN { INDArray targets = getSequence(p + 1); // sample from the model every now and then - if (n % 1000 == 0) { + if (n % 1000 == 0 && n > 0) { String txt = sample(inputs.getInt(0)); System.out.printf("\n---\n %s \n----\n", txt); } @@ -195,9 +200,9 @@ public class StackedRNN extends RNN { if (hs2 == null) { hs2 = init(inputs.length(), hst2); } - hs.putRow(t, hst); + hs2.putRow(t, hst2); - INDArray yst = (wh2y.mmul(hst)).add(by); // unnormalized log probabilities for next chars + INDArray yst = (wh2y.mmul(hst2)).add(by); // unnormalized log probabilities for next chars if (ys == null) { ys = init(inputs.length(), yst); } @@ -226,6 +231,7 @@ public class StackedRNN extends RNN { dby.addi(dy); INDArray dh2 = wh2y.transpose().mmul(dy).add(dh2Next); // backprop into h2 INDArray dhraw2 = (Nd4j.ones(hst2.shape()).sub(hst2).mul(hst2)).mul(dh2); // backprop through tanh nonlinearity +// INDArray dhraw2 = Nd4j.getExecutioner().execAndReturn(new SetRange(hst2, 0, Double.MAX_VALUE));; // backprop through relu nonlinearity dbh2.addi(dhraw2); INDArray hst = hs.getRow(t); @@ -233,6 +239,7 @@ public class StackedRNN extends RNN { dbh.addi(dh2); 
INDArray dh = whh2.transpose().mmul(dh2).add(dhNext); // backprop into h INDArray dhraw = (Nd4j.ones(hst.shape()).sub(hst).mul(hst)).mul(dh); // backprop through tanh nonlinearity +// INDArray dhraw = Nd4j.getExecutioner().execAndReturn(new SetRange(hst, 0, Double.MAX_VALUE));; // backprop through relu nonlinearity dbh.addi(dhraw); dWxh.addi(dhraw.mmul(xs.getRow(t))); @@ -255,6 +262,7 @@ public class StackedRNN extends RNN { /** * sample a sequence of integers from the model, using current (hPrev) memory state, seedIx is seed letter for first time step */ + @Override public String sample(int seedIx) { INDArray x = Nd4j.zeros(vocabSize, 1); @@ -285,4 +293,25 @@ public class StackedRNN extends RNN { return getSampleString(ixes); } + @Override + public void serialize(String prefix) throws IOException { + BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(new File(prefix + new Date().toString() + ".csv"))); + bufferedWriter.write("wxh"); + bufferedWriter.write(wxh.toString()); + bufferedWriter.write("whh"); + bufferedWriter.write(whh.toString()); + bufferedWriter.write("whh2"); + bufferedWriter.write(whh2.toString()); + bufferedWriter.write("wh2y"); + bufferedWriter.write(wh2y.toString()); + bufferedWriter.write("bh"); + bufferedWriter.write(bh.toString()); + bufferedWriter.write("bh2"); + bufferedWriter.write(bh2.toString()); + bufferedWriter.write("by"); + bufferedWriter.write(by.toString()); + bufferedWriter.flush(); + bufferedWriter.close(); + } + } Modified: labs/yay/trunk/core/src/test/java/org/apache/yay/RNNCrossValidationTest.java URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/RNNCrossValidationTest.java?rev=1764824&r1=1764823&r2=1764824&view=diff ============================================================================== --- labs/yay/trunk/core/src/test/java/org/apache/yay/RNNCrossValidationTest.java (original) +++ labs/yay/trunk/core/src/test/java/org/apache/yay/RNNCrossValidationTest.java Fri Oct 14 
07:18:02 2016 @@ -19,6 +19,7 @@ package org.apache.yay; import org.apache.commons.io.IOUtils; +import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; @@ -39,6 +40,9 @@ public class RNNCrossValidationTest { private int seqLength; private int hiddenLayerSize; private Random r = new Random(); + private String text; + private final int epochs = 2; + private List<String> words; public RNNCrossValidationTest(float learningRate, int seqLength, int hiddenLayerSize) { this.learningRate = learningRate; @@ -46,71 +50,73 @@ public class RNNCrossValidationTest { this.hiddenLayerSize = hiddenLayerSize; } + @Before + public void setUp() throws Exception { + InputStream stream = getClass().getResourceAsStream("/word2vec/abstracts.txt"); + text = IOUtils.toString(stream); + words = Arrays.asList(text.split(" ")); + stream.close(); + } + @Parameterized.Parameters public static Collection<Object[]> data() { return Arrays.asList(new Object[][]{ - {1e-1f, 25, 100}, - {3e-1f, 25, 100}, - {3e-1f, 100, 25}, - {1e-1f, 100, 25}, + {3e-1f, 50, 5}, + {3e-1f, 50, 10}, + {3e-1f, 50, 15}, + {3e-1f, 50, 25}, + {3e-1f, 50, 50}, + {3e-1f, 50, 100}, }); } @Test public void testVanillaCharRNNLearn() throws Exception { - InputStream resourceAsStream = getClass().getResourceAsStream("/word2vec/sentences.txt"); - String text = IOUtils.toString(resourceAsStream); - int epochs = 10; - RNN RNN = new RNN(learningRate, seqLength, hiddenLayerSize, epochs, text); - checkCorrectWordsRatio(text, RNN); + RNN rnn = new RNN(learningRate, seqLength, hiddenLayerSize, epochs, text); + evaluate(rnn, true); + rnn.serialize("target/crnn-weights-"); } @Test public void testVanillaWordRNNLearn() throws Exception { - InputStream resourceAsStream = getClass().getResourceAsStream("/word2vec/sentences.txt"); - String text = IOUtils.toString(resourceAsStream); - int epochs = 100; - RNN RNN = new RNN(learningRate, seqLength, hiddenLayerSize, epochs, text, false); 
- checkCorrectWordsRatio(text, RNN); + RNN rnn = new RNN(learningRate, seqLength, hiddenLayerSize, epochs, text, false); + evaluate(rnn, false); + rnn.serialize("target/wrnn-weights-"); } @Test public void testStackedCharRNNLearn() throws Exception { - InputStream resourceAsStream = getClass().getResourceAsStream("/word2vec/sentences.txt"); - String text = IOUtils.toString(resourceAsStream); - int epochs = 100; - RNN RNN = new StackedRNN(learningRate, seqLength, hiddenLayerSize, epochs, text); - checkCorrectWordsRatio(text, RNN); + RNN rnn = new StackedRNN(learningRate, seqLength, hiddenLayerSize, epochs, text); + evaluate(rnn, true); + rnn.serialize("target/scrnn-weights-"); } @Test public void testStackedWordRNNLearn() throws Exception { - InputStream resourceAsStream = getClass().getResourceAsStream("/word2vec/sentences.txt"); - String text = IOUtils.toString(resourceAsStream); - int epochs = 100; - RNN RNN = new StackedRNN(learningRate, seqLength, hiddenLayerSize, epochs, text, false); - checkCorrectWordsRatio(text, RNN); + RNN rnn = new StackedRNN(learningRate, seqLength, hiddenLayerSize, epochs, text, false); + evaluate(rnn, false); + rnn.serialize("target/swrnn-weights-"); } - private void checkCorrectWordsRatio(String text, RNN RNN) { - System.out.println(RNN); - List<String> words = Arrays.asList(text.split(" ")); - RNN.learn(); + private void evaluate(RNN rnn, boolean checkRatio) { + System.out.println(rnn); + rnn.learn(); + double c = 0; for (int i = 0; i < 10; i++) { - double c = 0; - String sample = RNN.sample(r.nextInt(RNN.getVocabSize())); - String[] sampleWords = sample.split(" "); - for (String sw : sampleWords) { - if (words.contains(sw)) { - c++; + String sample = rnn.sample(r.nextInt(rnn.getVocabSize())); + if (checkRatio) { + String[] sampleWords = sample.split(" "); + for (String sw : sampleWords) { + if (words.contains(sw)) { + c++; + } + } + if (c > 0) { + c /= sampleWords.length; } } - if (c > 0) { - c /= sample.length(); - } - 
System.out.println("correct word ratio: " + c); - } + System.out.println("average correct word ratio: " + (c / 10d)); } } \ No newline at end of file --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscribe@labs.apache.org For additional commands, e-mail: commits-help@labs.apache.org