Author: tommaso Date: Wed Oct 5 16:12:24 2016 New Revision: 1763464 URL: http://svn.apache.org/viewvc?rev=1763464&view=rev Log: minor fixes, java raw char-rnn model inspired by karpathy
Added: labs/yay/trunk/core/src/main/java/org/apache/yay/RNN.java (with props) labs/yay/trunk/core/src/test/java/org/apache/yay/RNNTest.java (with props) Modified: labs/yay/trunk/core/pom.xml labs/yay/trunk/core/src/main/java/org/apache/yay/MultiLayerNetwork.java labs/yay/trunk/core/src/main/java/org/apache/yay/SkipGramNetwork.java labs/yay/trunk/core/src/test/java/org/apache/yay/MultiLayerNetworkTest.java labs/yay/trunk/core/src/test/java/org/apache/yay/SkipGramNetworkTest.java Modified: labs/yay/trunk/core/pom.xml URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/pom.xml?rev=1763464&r1=1763463&r2=1763464&view=diff ============================================================================== --- labs/yay/trunk/core/pom.xml (original) +++ labs/yay/trunk/core/pom.xml Wed Oct 5 16:12:24 2016 @@ -25,6 +25,9 @@ <version>0.2-SNAPSHOT</version> <relativePath>../</relativePath> </parent> + <properties> + <nd4j.version>0.6.0</nd4j.version> + </properties> <name>Yay core</name> <dependencies> <dependency> @@ -52,6 +55,12 @@ <artifactId>guava</artifactId> <version>18.0</version> </dependency> + <dependency> + <groupId>org.nd4j</groupId> + <artifactId>nd4j-native-platform</artifactId> + <version>${nd4j.version}</version> + </dependency> + </dependencies> <build> <plugins> Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/MultiLayerNetwork.java URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/MultiLayerNetwork.java?rev=1763464&r1=1763463&r2=1763464&view=diff ============================================================================== --- labs/yay/trunk/core/src/main/java/org/apache/yay/MultiLayerNetwork.java (original) +++ labs/yay/trunk/core/src/main/java/org/apache/yay/MultiLayerNetwork.java Wed Oct 5 16:12:24 2016 @@ -110,9 +110,7 @@ public class MultiLayerNetwork { while (true) { if (iterations % (1 + (configuration.maxIterations / 100)) == 0) { long time = (System.currentTimeMillis() - start) / 1000; -// if (time > 60) { System.out.println("cost is " + cost + " after " + iterations + " iterations in " + (time / 60) + " minutes (" + ((double) iterations / time) + " ips)"); -// } } // current training example Sample sample = samples[iterations % samples.length]; Added: labs/yay/trunk/core/src/main/java/org/apache/yay/RNN.java URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/RNN.java?rev=1763464&view=auto ============================================================================== --- labs/yay/trunk/core/src/main/java/org/apache/yay/RNN.java (added) +++ labs/yay/trunk/core/src/main/java/org/apache/yay/RNN.java Wed Oct 5 16:12:24 2016 @@ -0,0 +1,277 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.yay; + +import org.apache.commons.math3.distribution.EnumeratedDistribution; +import org.apache.commons.math3.util.Pair; +import org.nd4j.linalg.api.iter.NdIndexIterator; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.impl.transforms.SetRange; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.ops.transforms.Transforms; + +import java.util.HashMap; +import java.util.HashSet; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Set; + +/** + * A min char-level vanilla RNN model, based on Andrej Karpathy's python code. + * See also: + * + * @see <a href="http://karpathy.github.io/2015/05/21/rnn-effectiveness">The Unreasonable Effectiveness of Recurrent Neural Networks</a> + * @see <a href="https://gist.github.com/karpathy/d4dee566867f8291f086">Minimal character-level language model with a Vanilla Recurrent Neural Network, in Python/numpy</a> + */ +public class RNN { + + public void learn(String text) { + + char[] textChars = text.toCharArray(); + List<Character> data = new LinkedList<>(); + for (char c : textChars) { + data.add(c); + } + Set<Character> chars = new HashSet<>(data); + int vocabSize = chars.size(); + System.out.printf("data has %d characters, %d unique.", data.size(), vocabSize); + Map<Character, Integer> charToIx = new HashMap<>(); + Map<Integer, Character> ixToChar = new HashMap<>(); + int i = 0; + for (Character c : chars) { + charToIx.put(c, i); + ixToChar.put(i, c); + i++; + } + + // hyperparameters + int hiddenSize = 40; // size of hidden layer of neurons + int seqLength = 10; // no. of steps to unroll the RNN for + float learningRate = 1e-2f; + + // model parameters + INDArray wxh = Nd4j.randn(hiddenSize, vocabSize).mul(0.001); // input to hidden + INDArray whh = Nd4j.randn(hiddenSize, hiddenSize).mul(0.001); // hidden to hidden + INDArray why = Nd4j.randn(vocabSize, hiddenSize).mul(0.001); // hidden to output + INDArray bh = Nd4j.zeros(hiddenSize, 1); // hidden bias + INDArray by = Nd4j.zeros(vocabSize, 1); // output bias + + int n = 0; + int p = 0; + + // memory variables for Adagrad + INDArray mWxh = Nd4j.zerosLike(wxh); + INDArray mWhh = Nd4j.zerosLike(whh); + INDArray mWhy = Nd4j.zerosLike(why); + + INDArray mbh = Nd4j.zerosLike(bh); + INDArray mby = Nd4j.zerosLike(by); + + // loss at iteration 0 + double smoothLoss = -Math.log(1.0 / vocabSize) * seqLength; + + INDArray hPrev = null; + while (true) { + // prepare inputs (we're sweeping from left to right in steps seqLength long) + if (p + seqLength + 1 >= data.size() || n == 0) { + hPrev = Nd4j.zeros(hiddenSize, 1); // reset RNN memory + p = 0; // go from start of data + } + + INDArray inputs = Nd4j.create(seqLength); + int c = 0; + for (Character ch : data.subList(p, p + seqLength)) { + Integer ix = charToIx.get(ch); + inputs.putScalar(c, ix); + c++; + } + INDArray targets = Nd4j.create(seqLength); + c = 0; + for (Character ch : data.subList(p + 1, p + seqLength + 1)) { + Integer ix = charToIx.get(ch); + targets.putScalar(c, ix); + c++; + } + + // sample from the model now and then + if (n % 1000 == 0) { + sample(vocabSize, ixToChar, wxh, whh, why, bh, by, hPrev, inputs); + } + + INDArray dWxh = Nd4j.zerosLike(wxh); + INDArray dWhh = Nd4j.zerosLike(whh); + INDArray dWhy = Nd4j.zerosLike(why); + + INDArray dbh = Nd4j.zerosLike(bh); + INDArray dby = Nd4j.zerosLike(by); + + // forward seqLength characters through the net and fetch gradient + double loss = lossFun(vocabSize, wxh, whh, why, bh, by, hPrev, inputs, targets, dWxh, dWhh, dWhy, dbh, dby); + smoothLoss = smoothLoss * 0.99 + loss * 0.001; + if (n % 100 == 0) { + System.out.printf("iter %d, loss: %f\n", n, smoothLoss); // print progress + } + + // perform parameter update with Adagrad + mWxh.addi(dWxh.mul(dWxh)); + wxh.subi((dWxh.mul(learningRate)).div(Transforms.sqrt(mWxh.addi(1e-8)))); + + mWhh.addi(dWhh.mul(dWhh)); + whh.subi(dWhh.mul(learningRate).div(Transforms.sqrt(mWhh.addi(1e-8)))); + + mWhy.addi(dWhy.mul(dWhy)); + why.subi(dWhy.mul(learningRate).div(Transforms.sqrt(mWhy.addi(1e-8)))); + + mbh.addi(dbh.mul(dbh)); + bh.subi(dbh.mul(learningRate).div(Transforms.sqrt(mbh.addi(1e-8)))); + + mby.addi(dby.mul(dby)); + by.subi(dby.mul(learningRate).div(Transforms.sqrt(mby.addi(1e-8)))); + + p += seqLength; // move data pointer + n++; // iteration counter + } + } + + /** + * inputs, targets are both list of integers + * hprev is Hx1 array of initial hidden state + * returns the loss, gradients on model parameters and last hidden state + */ + private double lossFun(int vocabSize, INDArray wxh, INDArray whh, INDArray why, INDArray bh, INDArray by, INDArray hPrev, + INDArray inputs, INDArray targets, INDArray dWxh, INDArray dWhh, INDArray dWhy, INDArray dbh, + INDArray dby) { + + INDArray xs = Nd4j.zeros(inputs.length(), vocabSize); + INDArray hs = null; + INDArray ys = null; + INDArray ps = null; + + INDArray hs1 = Nd4j.create(hPrev.shape()); + Nd4j.copy(hPrev, hs1); + + double loss = 0; + + // forward pass + for (int t = 0; t < inputs.length(); t++) { + int tIndex = inputs.getScalar(t).getInt(0); + xs.putScalar(t, tIndex, 1); // encode in 1-of-k representation + INDArray hsRow = t == 0 ? hs1 : hs.getRow(t - 1); + INDArray hst = Transforms.tanh((wxh.mmul(xs.getRow(t).transpose())).add((whh.mmul(hsRow)).add(bh))); // hidden state + if (hs == null) { + hs = init(inputs.length(), hst); + } + hs.putRow(t, hst); + + INDArray yst = (why.mmul(hst)).add(by); // unnormalized log probabilities for next chars + if (ys == null) { + ys = init(inputs.length(), yst); + } + ys.putRow(t, yst); + INDArray exp = Transforms.exp(yst); + Number sumExp = exp.sumNumber(); + INDArray pst = exp.div(sumExp); // probabilities for next chars + if (ps == null) { + ps = init(inputs.length(), pst); + } + ps.putRow(t, pst); + loss += -Transforms.log(ps.getRow(t).getRow(targets.getInt(t)), true).sumNumber().doubleValue(); // softmax (cross-entropy loss) + } + + // backward pass: compute gradients going backwards + INDArray dhNext = Nd4j.zerosLike(hs.getRow(0)); + for (int t = inputs.length() - 1; t >= 0; t--) { + INDArray dy = ps.getRow(t).dup(); + dy.putRow(targets.getInt(t), dy.getRow(targets.getInt(t)).sub(1)); // backprop into y + INDArray hst = hs.getRow(t); + dWhy.addi(dy.mmul(hst.transpose())); + dby.addi(dy); + INDArray dh = why.transpose().mmul(dy).add(dhNext); // backprop into h + INDArray dhraw = (Nd4j.ones(hst.shape()).sub(hst).mul(hst)).mul(dh); // backprop through tanh nonlinearity + dbh.addi(dhraw); + dWxh.addi(dhraw.mmul(xs.getRow(t))); + INDArray hsRow = t == 0 ? hs1 : hs.getRow(t - 1); + dWhh.addi(dhraw.mmul(hsRow.transpose())); + dhNext = whh.transpose().mmul(dhraw); + + } + // clip exploding gradients + Nd4j.getExecutioner().execAndReturn(new SetRange(dWxh, -5, 5)); + Nd4j.getExecutioner().execAndReturn(new SetRange(dWhh, -5, 5)); + Nd4j.getExecutioner().execAndReturn(new SetRange(dWhy, -5, 5)); + Nd4j.getExecutioner().execAndReturn(new SetRange(dbh, -5, 5)); + Nd4j.getExecutioner().execAndReturn(new SetRange(dby, -5, 5)); + + return loss; + } + + private INDArray init(int t, INDArray ast) { + INDArray as; + int[] aShape = ast.shape(); + int[] shape = new int[1 + aShape.length]; + shape[0] = t; + System.arraycopy(aShape, 0, shape, 1, aShape.length); + as = Nd4j.create(shape); + return as; + } + + /** + * sample a sequence of integers from the model, h is memory state, seed_ix is seed letter for first time step + */ + private void sample(int vocabSize, Map<Integer, Character> ixToChar, INDArray wxh, INDArray whh, INDArray why, + INDArray bh, INDArray by, INDArray hPrev, INDArray inputs) { + + INDArray x = Nd4j.zeros(vocabSize, 1); + int seedIx = inputs.getInt(0); + x.putScalar(seedIx, 1); + int sampleSize = 200; + INDArray ixes = Nd4j.create(sampleSize); + + for (int t = 0; t < sampleSize; t++) { + INDArray h = Transforms.tanh((wxh.mmul(x)).add((whh.mmul(hPrev)).add(bh))); + INDArray y = (why.mmul(h)).add(by); + INDArray exp = Transforms.exp(y); + INDArray pm = exp.div(Nd4j.sum(exp)).ravel(); + + List<Pair<Integer, Double>> d = new LinkedList<>(); + for (int pi = 0; pi < vocabSize; pi++) { + d.add(new Pair<>(pi, pm.getDouble(0, pi))); + } + EnumeratedDistribution<Integer> distribution = new EnumeratedDistribution<>(d); + + int ix = distribution.sample(); + + x = Nd4j.zeros(vocabSize, 1); + x.putScalar(ix, 1); + ixes.putScalar(t, ix); + } + + String txt = ""; + + + NdIndexIterator ndIndexIterator = new NdIndexIterator(ixes.shape()); + while (ndIndexIterator.hasNext()) { + int[] next = ndIndexIterator.next(); + txt += ixToChar.get(ixes.getInt(next)); + } + System.out.printf("\n---\n %s \n----\n", txt); + } + +} Propchange: labs/yay/trunk/core/src/main/java/org/apache/yay/RNN.java ------------------------------------------------------------------------------ svn:eol-style = native Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/SkipGramNetwork.java URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/SkipGramNetwork.java?rev=1763464&r1=1763463&r2=1763464&view=diff ============================================================================== --- labs/yay/trunk/core/src/main/java/org/apache/yay/SkipGramNetwork.java (original) +++ labs/yay/trunk/core/src/main/java/org/apache/yay/SkipGramNetwork.java Wed Oct 5 16:12:24 2016 @@ -62,8 +62,8 @@ public class SkipGramNetwork { * the first row of weighs[0]Â matrix holds the weights of each neuron in the first neuron of the second layer, * the second row of weighs[0]Â holds the weights of each neuron in the second neuron of the second layer, etc. */ - private RealMatrix[] weights; - private RealMatrix[] biases; + private final RealMatrix[] weights; + private final RealMatrix[] biases; private Sample[] samples; @@ -184,7 +184,7 @@ public class SkipGramNetwork { * @return the output * @throws Exception */ - public double[] predictOutput(double[] input) throws Exception { + private double[] predictOutput(double[] input) throws Exception { RealMatrix hidden = rectifierFunction.applyMatrix(MatrixUtils.createRowRealMatrix(input).multiply(weights[0].transpose()). add(biases[0])); @@ -214,7 +214,7 @@ public class SkipGramNetwork { * @return the final cost with the updated weights * @throws Exception if BGD fails to converge or any numerical error happens */ - public double learnWeights(Sample... samples) throws Exception { + private double learnWeights(Sample... samples) throws Exception { int iterations = 0; @@ -250,7 +250,7 @@ public class SkipGramNetwork { } RealMatrix w0t = weights[0].transpose(); - final RealMatrix w1t = weights[1].transpose(); + RealMatrix w1t = weights[1].transpose(); RealMatrix hidden = rectifierFunction.applyMatrix(x.multiply(w0t)); hidden.walkInOptimizedOrder(new RealMatrixChangingVisitor() { @@ -912,23 +912,23 @@ public class SkipGramNetwork { private static class Configuration { // internal parameters - protected int outputs; - protected int inputs; + int outputs; + int inputs; - protected List<String> vocabulary; + List<String> vocabulary; // user controlled parameters - protected Path path; - protected int maxIterations; - protected double alpha = 0.5d; - protected double mu = 0.9d; - protected double regularizationLambda = 0.03; - protected double threshold = 0.0000000000004d; - protected int vectorSize; - protected int window; - protected boolean useMomentum; - protected boolean useNesterovMomentum; - protected int batchSize; + Path path; + int maxIterations; + double alpha = 0.5d; + double mu = 0.9d; + double regularizationLambda = 0.03; + double threshold = 0.0000000000004d; + int vectorSize; + int window; + boolean useMomentum; + boolean useNesterovMomentum; + int batchSize; } public static class Builder { @@ -1039,7 +1039,7 @@ public class SkipGramNetwork { return vocabulary; } - private Collection<HotEncodedSample> createTrainingSet(final List<String> vocabulary, Queue<List<byte[]>> fragments, int window) throws IOException { + private Collection<HotEncodedSample> createTrainingSet(final List<String> vocabulary, Queue<List<byte[]>> fragments, int window) throws Exception { long start = System.currentTimeMillis(); Collection<HotEncodedSample> samples = new LinkedList<>(); List<byte[]> fragment; @@ -1063,8 +1063,9 @@ public class SkipGramNetwork { String x = new String(inputWord); inputs[0] = (double) vocabulary.indexOf(x); - samples.add(new HotEncodedSample(inputs, doubles, vocabulary.size())); -// System.err.println("added: " + x + " -> " + Arrays.toString(os.toArray())); + HotEncodedSample hotEncodedSample = new HotEncodedSample(inputs, doubles, vocabulary.size()); + samples.add(hotEncodedSample); +// System.err.println("added: " + x + " -> " + hotEncodedSample); } long end = System.currentTimeMillis(); Modified: labs/yay/trunk/core/src/test/java/org/apache/yay/MultiLayerNetworkTest.java URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/MultiLayerNetworkTest.java?rev=1763464&r1=1763463&r2=1763464&view=diff ============================================================================== --- labs/yay/trunk/core/src/test/java/org/apache/yay/MultiLayerNetworkTest.java (original) +++ labs/yay/trunk/core/src/test/java/org/apache/yay/MultiLayerNetworkTest.java Wed Oct 5 16:12:24 2016 @@ -36,7 +36,7 @@ public class MultiLayerNetworkTest { @Test public void testLearnAndPredict() throws Exception { MultiLayerNetwork.Configuration configuration = new MultiLayerNetwork.Configuration(); - configuration.alpha = 0.0000001d; + configuration.alpha = 0.000000001d; configuration.layers = new int[]{3, 4, 1}; configuration.maxIterations = 1000000; configuration.threshold = 0.00000004d; Added: labs/yay/trunk/core/src/test/java/org/apache/yay/RNNTest.java URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/RNNTest.java?rev=1763464&view=auto ============================================================================== --- labs/yay/trunk/core/src/test/java/org/apache/yay/RNNTest.java (added) +++ labs/yay/trunk/core/src/test/java/org/apache/yay/RNNTest.java Wed Oct 5 16:12:24 2016 @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.yay; + +import org.apache.commons.io.IOUtils; +import org.junit.Test; + +import java.io.InputStream; + +/** + * Tests for {@link RNN} + */ +public class RNNTest { + + @Test + public void test() throws Exception { + InputStream resourceAsStream = getClass().getResourceAsStream("/word2vec/abstracts.txt"); + String text = IOUtils.toString(resourceAsStream); + RNN n = new RNN(); + n.learn(text); + } + +} \ No newline at end of file Propchange: labs/yay/trunk/core/src/test/java/org/apache/yay/RNNTest.java ------------------------------------------------------------------------------ svn:eol-style = native Modified: labs/yay/trunk/core/src/test/java/org/apache/yay/SkipGramNetworkTest.java URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/SkipGramNetworkTest.java?rev=1763464&r1=1763463&r2=1763464&view=diff ============================================================================== --- labs/yay/trunk/core/src/test/java/org/apache/yay/SkipGramNetworkTest.java (original) +++ labs/yay/trunk/core/src/test/java/org/apache/yay/SkipGramNetworkTest.java Wed Oct 5 16:12:24 2016 @@ -47,12 +47,33 @@ public class SkipGramNetworkTest { withWindow(3). fromTextAt(path). withDimension(10). - withAlpha(0.01). + withAlpha(0.1). withLambda(0.0001). useNesterovMomentum(true). withMu(0.9). withMaxIterations(30000). - withBatchSize(10). + withBatchSize(1). + build(); + RealMatrix wv = network.getWeights()[0]; + List<String> vocabulary = network.getVocabulary(); + serialize(vocabulary, wv); + System.err.println("accuracy: " + SkipGramNetwork.evaluate(network)); + measure(vocabulary, wv); + } + + @Test + public void testWordVectorsLearningOnBigText() throws Exception { + Path path = Paths.get(getClass().getResource("/word2vec/big.txt").getFile()); + SkipGramNetwork network = SkipGramNetwork.newModel(). + withWindow(3). + fromTextAt(path). + withDimension(2). + withAlpha(0.1). + withLambda(0.0001). + useNesterovMomentum(true). + withMu(0.9). + withMaxIterations(1000000). + withBatchSize(1). build(); RealMatrix wv = network.getWeights()[0]; List<String> vocabulary = network.getVocabulary(); --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@labs.apache.org For additional commands, e-mail: commits-h...@labs.apache.org