Author: tommaso Date: Wed Oct 12 12:36:05 2016 New Revision: 1764448 URL: http://svn.apache.org/viewvc?rev=1764448&view=rev Log: added sRNN
Added: labs/yay/trunk/core/src/main/java/org/apache/yay/CharStackedRNN.java (with props) Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/CharRNN.java labs/yay/trunk/core/src/test/java/org/apache/yay/CharRNNCrossValidationTest.java labs/yay/trunk/core/src/test/java/org/apache/yay/WordRNNCrossValidationTest.java Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/CharRNN.java URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/CharRNN.java?rev=1764448&r1=1764447&r2=1764448&view=diff ============================================================================== --- labs/yay/trunk/core/src/main/java/org/apache/yay/CharRNN.java (original) +++ labs/yay/trunk/core/src/main/java/org/apache/yay/CharRNN.java Wed Oct 12 12:36:05 2016 @@ -43,14 +43,15 @@ import java.util.Set; public class CharRNN { // hyperparameters - private final float learningRate; // size of hidden layer of neurons - private final int seqLength; // no. of steps to unroll the RNN for - private final int hiddenLayerSize; - private final int epochs; - private final int vocabSize; - private final Map<Character, Integer> charToIx; - private final Map<Integer, Character> ixToChar; - private final List<Character> data; + protected final float learningRate; // size of hidden layer of neurons + protected final int seqLength; // no. of steps to unroll the RNN for + protected final int hiddenLayerSize; + protected final int epochs; + private final boolean useChars; + protected final int vocabSize; + protected final Map<String, Integer> charToIx; + protected final Map<Integer, String> ixToChar; + protected final List<String> data; // model parameters private final INDArray wxh; // input to hidden @@ -62,22 +63,29 @@ public class CharRNN { private INDArray hPrev = null; // memory state public CharRNN(float learningRate, int seqLength, int hiddenLayerSize, int epochs, String text) { + this(learningRate, seqLength, hiddenLayerSize, epochs, text, true); + } + + public CharRNN(float learningRate, int seqLength, int hiddenLayerSize, int epochs, String text, boolean useChars) { this.learningRate = learningRate; this.seqLength = seqLength; this.hiddenLayerSize = hiddenLayerSize; this.epochs = epochs; - char[] textChars = text.toCharArray(); + this.useChars = useChars; + + String[] textTokens = useChars ? toStrings(text.toCharArray()) : text.split(" "); data = new LinkedList<>(); - for (char c : textChars) { + for (String c : textTokens) { data.add(c); } - Set<Character> chars = new HashSet<>(data); - vocabSize = chars.size(); - System.out.printf("data has %d characters, %d unique.", data.size(), vocabSize); + Set<String> tokens = new HashSet<>(data); + vocabSize = tokens.size(); + + System.out.printf("data has %d tokens, %d unique.", data.size(), vocabSize); charToIx = new HashMap<>(); ixToChar = new HashMap<>(); int i = 0; - for (Character c : chars) { + for (String c : tokens) { charToIx.put(c, i); ixToChar.put(i, c); i++; @@ -90,6 +98,14 @@ public class CharRNN { by = Nd4j.zeros(vocabSize, 1).mul(0.01); } + private String[] toStrings(char[] chars) { + String[] strings = new String[chars.length]; + for (int i = 0; i < chars.length; i++) { + strings[i] = String.valueOf(chars[i]); + } + return strings; + } + public void learn() { int currentEpoch = 0; @@ -137,7 +153,7 @@ public class CharRNN { INDArray dby = Nd4j.zerosLike(by); // forward seqLength characters through the net and fetch gradient - double loss = lossFun(vocabSize, wxh, whh, why, bh, by, hPrev, inputs, targets, dWxh, dWhh, dWhy, dbh, dby); + double loss = lossFun(inputs, targets, dWxh, dWhh, dWhy, dbh, dby); smoothLoss = smoothLoss * 0.99 + loss * 0.001; if (Double.isNaN(smoothLoss)) { System.out.println("loss is NaN (over/underflow occured, try adjusting hyperparameters)"); @@ -168,10 +184,10 @@ public class CharRNN { } } - private INDArray getSequence(int p) { + protected INDArray getSequence(int p) { INDArray inputs = Nd4j.create(seqLength); int c = 0; - for (Character ch : data.subList(p, p + seqLength)) { + for (String ch : data.subList(p, p + seqLength)) { Integer ix = charToIx.get(ch); inputs.putScalar(c, ix); c++; @@ -184,8 +200,7 @@ public class CharRNN { * hprev is Hx1 array of initial hidden state * returns the loss, gradients on model parameters and last hidden state */ - private double lossFun(int vocabSize, INDArray wxh, INDArray whh, INDArray why, INDArray bh, INDArray by, INDArray hPrev, - INDArray inputs, INDArray targets, INDArray dWxh, INDArray dWhh, INDArray dWhy, INDArray dbh, + private double lossFun(INDArray inputs, INDArray targets, INDArray dWxh, INDArray dWhh, INDArray dWhy, INDArray dbh, INDArray dby) { INDArray xs = Nd4j.zeros(inputs.length(), vocabSize); @@ -253,7 +268,7 @@ public class CharRNN { return loss; } - private INDArray init(int t, INDArray ast) { + protected INDArray init(int t, INDArray ast) { INDArray as; int[] aShape = ast.shape(); int[] shape = new int[1 + aShape.length]; @@ -306,4 +321,26 @@ public class CharRNN { return vocabSize; } + @Override + public String toString() { + return "CharRNN{" + + "learningRate=" + learningRate + + ", seqLength=" + seqLength + + ", hiddenLayerSize=" + hiddenLayerSize + + ", epochs=" + epochs + + ", vocabSize=" + vocabSize + + ", useChars=" + useChars + + '}'; + } + + + public String getHyperparamsString() { + return "CharRNN{" + + ", wxh=" + wxh + + ", whh=" + whh + + ", why=" + why + + ", bh=" + bh + + ", by=" + by + + '}'; + } } Added: labs/yay/trunk/core/src/main/java/org/apache/yay/CharStackedRNN.java URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/CharStackedRNN.java?rev=1764448&view=auto ============================================================================== --- labs/yay/trunk/core/src/main/java/org/apache/yay/CharStackedRNN.java (added) +++ labs/yay/trunk/core/src/main/java/org/apache/yay/CharStackedRNN.java Wed Oct 12 12:36:05 2016 @@ -0,0 +1,296 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.yay; + +import org.apache.commons.math3.distribution.EnumeratedDistribution; +import org.apache.commons.math3.util.Pair; +import org.nd4j.linalg.api.iter.NdIndexIterator; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.impl.transforms.SetRange; +import org.nd4j.linalg.factory.Nd4j; +import org.nd4j.linalg.ops.transforms.Transforms; + +import java.util.LinkedList; +import java.util.List; + +/** + * A basic char-level stacked RNN model (2 hidden recurrent layers), based on Stacked RNN architecture from ICLR 2014's + * "How to Construct Deep Recurrent Neural Networks" by Razvan Pascanu, Caglar Gulcehre, Kyunghyun Cho and Yoshua Bengio + * and Andrej Karpathy's notes on RNNs. + * See also: + * + * @see <a href="http://karpathy.github.io/2015/05/21/rnn-effectiveness">The Unreasonable Effectiveness of Recurrent Neural Networks</a> + * @see <a href="https://arxiv.org/abs/1312.6026">How to Construct Deep Recurrent Neural Networks</a> + */ +public class CharStackedRNN extends CharRNN { + + // model parameters + private final INDArray wxh; // input to hidden + private final INDArray whh; // hidden to hidden + private final INDArray whh2; // hidden to hidden2 + private final INDArray wh2y; // hidden2 to output + private final INDArray bh; // hidden bias + private final INDArray bh2; // hidden2 bias + private final INDArray by; // output bias + + private INDArray hPrev = null; // memory state + private INDArray hPrev2 = null; // memory state + + public CharStackedRNN(float learningRate, int seqLength, int hiddenLayerSize, int epochs, String text) { + this(learningRate, seqLength, hiddenLayerSize, epochs, text, true); + } + + public CharStackedRNN(float learningRate, int seqLength, int hiddenLayerSize, int epochs, String text, boolean useChars) { + super(learningRate, seqLength, hiddenLayerSize, epochs, text, useChars); + + wxh = Nd4j.randn(hiddenLayerSize, vocabSize).mul(0.01); + whh = Nd4j.randn(hiddenLayerSize, hiddenLayerSize).mul(0.01); + whh2 = Nd4j.randn(hiddenLayerSize, hiddenLayerSize).mul(0.01); + wh2y = Nd4j.randn(vocabSize, hiddenLayerSize).mul(0.01); + bh = Nd4j.zeros(hiddenLayerSize, 1).mul(0.01); + bh2 = Nd4j.zeros(hiddenLayerSize, 1).mul(0.01); + by = Nd4j.zeros(vocabSize, 1).mul(0.01); + } + + public void learn() { + + int currentEpoch = 0; + + int n = 0; + int p = 0; + + // memory variables for Adagrad + INDArray mWxh = Nd4j.zerosLike(wxh); + INDArray mWhh = Nd4j.zerosLike(whh); + INDArray mWhh2 = Nd4j.zerosLike(whh2); + INDArray mWh2y = Nd4j.zerosLike(wh2y); + + INDArray mbh = Nd4j.zerosLike(bh); + INDArray mbh2 = Nd4j.zerosLike(bh2); + INDArray mby = Nd4j.zerosLike(by); + + // loss at iteration 0 + double smoothLoss = -Math.log(1.0 / vocabSize) * seqLength; + + while (true) { + // prepare inputs (we're sweeping from left to right in steps seqLength long) + if (p + seqLength + 1 >= data.size() || n == 0) { + hPrev = Nd4j.zeros(hiddenLayerSize, 1); // reset RNN memory + hPrev2 = Nd4j.zeros(hiddenLayerSize, 1); // reset RNN memory + p = 0; // go from start of data + currentEpoch++; + if (currentEpoch == epochs) { + System.out.println("training finished: e:" + epochs + ", l: " + smoothLoss + ", h:(" + learningRate + ", " + seqLength + ", " + hiddenLayerSize + ")"); + break; + } + } + + INDArray inputs = getSequence(p); + INDArray targets = getSequence(p + 1); + + // sample from the model every now and then + if (n % 1000 == 0) { + String txt = sample(inputs.getInt(0)); + System.out.printf("\n---\n %s \n----\n", txt); + } + + INDArray dWxh = Nd4j.zerosLike(wxh); + INDArray dWhh = Nd4j.zerosLike(whh); + INDArray dWhh2 = Nd4j.zerosLike(whh2); + INDArray dWh2y = Nd4j.zerosLike(wh2y); + + INDArray dbh = Nd4j.zerosLike(bh); + INDArray dbh2 = Nd4j.zerosLike(bh); + INDArray dby = Nd4j.zerosLike(by); + + // forward seqLength characters through the net and fetch gradient + double loss = lossFun(inputs, targets, dWxh, dWhh, dWhh2, dWh2y, dbh, dbh2, dby); + smoothLoss = smoothLoss * 0.99 + loss * 0.001; + if (Double.isNaN(smoothLoss)) { + System.out.println("loss is NaN (over/underflow occured, try adjusting hyperparameters)"); + break; + } + if (n % 1000 == 0) { + System.out.printf("iter %d, loss: %f\n", n, smoothLoss); // print progress + } + + // perform parameter update with Adagrad + mWxh.addi(dWxh.mul(dWxh)); + wxh.subi((dWxh.mul(learningRate)).div(Transforms.sqrt(mWxh.addi(1e-8)))); + + mWhh.addi(dWhh.mul(dWhh)); + whh.subi(dWhh.mul(learningRate).div(Transforms.sqrt(mWhh.addi(1e-8)))); + + mWhh2.addi(dWhh2.mul(dWhh2)); + whh2.subi(dWhh2.mul(learningRate).div(Transforms.sqrt(mWhh2.addi(1e-8)))); + + mbh2.addi(dbh2.mul(dbh2)); + bh2.subi(dbh2.mul(learningRate).div(Transforms.sqrt(mbh2.addi(1e-8)))); + + mWh2y.addi(dWh2y.mul(dWh2y)); + wh2y.subi(dWh2y.mul(learningRate).div(Transforms.sqrt(mWh2y.addi(1e-8)))); + + mbh.addi(dbh.mul(dbh)); + bh.subi(dbh.mul(learningRate).div(Transforms.sqrt(mbh.addi(1e-8)))); + + mby.addi(dby.mul(dby)); + by.subi(dby.mul(learningRate).div(Transforms.sqrt(mby.addi(1e-8)))); + + p += seqLength; // move data pointer + n++; // iteration counter + } + } + + /** + * inputs, targets are both list of integers + * hprev is Hx1 array of initial hidden state + * returns the loss, gradients on model parameters and last hidden state + */ + private double lossFun(INDArray inputs, INDArray targets, INDArray dWxh, INDArray dWhh, INDArray dWhh2, INDArray dWh2y, + INDArray dbh, INDArray dbh2, INDArray dby) { + + INDArray xs = Nd4j.zeros(inputs.length(), vocabSize); + INDArray hs = null; + INDArray hs2 = null; + INDArray ys = null; + INDArray ps = null; + + INDArray hs1 = Nd4j.create(hPrev.shape()); + Nd4j.copy(hPrev, hs1); + + INDArray hs12 = Nd4j.create(hPrev2.shape()); + Nd4j.copy(hPrev2, hs1); + + double loss = 0; + + // forward pass + for (int t = 0; t < inputs.length(); t++) { + int tIndex = inputs.getScalar(t).getInt(0); + xs.putScalar(t, tIndex, 1); // encode in 1-of-k representation + + INDArray hsRow = t == 0 ? hs1 : hs.getRow(t - 1); + INDArray hst = Transforms.tanh((wxh.mmul(xs.getRow(t).transpose())).add((whh.mmul(hsRow)).add(bh))); // hidden state + if (hs == null) { + hs = init(inputs.length(), hst); + } + hs.putRow(t, hst); + + INDArray hs2Row = t == 0 ? hs12 : hs2.getRow(t - 1); + INDArray hst2 = Transforms.tanh((whh.mmul(hs.getRow(t))).add((whh2.mmul(hs2Row)).add(bh2))); // hidden state 2 + if (hs2 == null) { + hs2 = init(inputs.length(), hst2); + } + hs.putRow(t, hst); + + INDArray yst = (wh2y.mmul(hst)).add(by); // unnormalized log probabilities for next chars + if (ys == null) { + ys = init(inputs.length(), yst); + } + ys.putRow(t, yst); + INDArray exp = Transforms.exp(yst); + Number sumExp = exp.sumNumber(); + INDArray pst = exp.div(sumExp); // probabilities for next chars + if (ps == null) { + ps = init(inputs.length(), pst); + } + ps.putRow(t, pst); + loss += -Transforms.log(ps.getRow(t).getRow(targets.getInt(t)), true).sumNumber().doubleValue(); // softmax (cross-entropy loss) + } + + this.hPrev = hs.getRow(inputs.length() - 1); + + // backward pass: compute gradients going backwards + INDArray dhNext = Nd4j.zerosLike(hs.getRow(0)); + INDArray dh2Next = Nd4j.zerosLike(hs2.getRow(0)); + for (int t = inputs.length() - 1; t >= 0; t--) { + INDArray dy = ps.getRow(t).dup(); + dy.putRow(targets.getInt(t), dy.getRow(targets.getInt(t)).sub(1)); // backprop into y + + INDArray hst2 = hs2.getRow(t); + dWh2y.addi(dy.mmul(hst2.transpose())); + dby.addi(dy); + INDArray dh2 = wh2y.transpose().mmul(dy).add(dh2Next); // backprop into h2 + INDArray dhraw2 = (Nd4j.ones(hst2.shape()).sub(hst2).mul(hst2)).mul(dh2); // backprop through tanh nonlinearity + dbh2.addi(dhraw2); + + INDArray hst = hs.getRow(t); + dWhh2.addi(dh2.mmul(hst.transpose())); + dbh.addi(dh2); + INDArray dh = whh2.transpose().mmul(dh2).add(dhNext); // backprop into h + INDArray dhraw = (Nd4j.ones(hst.shape()).sub(hst).mul(hst)).mul(dh); // backprop through tanh nonlinearity + dbh.addi(dhraw); + + dWxh.addi(dhraw.mmul(xs.getRow(t))); + INDArray hsRow = t == 0 ? hs1 : hs.getRow(t - 1); + dWhh.addi(dhraw.mmul(hsRow.transpose())); + dhNext = whh.transpose().mmul(dhraw); + + } + // clip exploding gradients + Nd4j.getExecutioner().execAndReturn(new SetRange(dWxh, -5, 5)); + Nd4j.getExecutioner().execAndReturn(new SetRange(dWhh, -5, 5)); + Nd4j.getExecutioner().execAndReturn(new SetRange(dWhh2, -5, 5)); + Nd4j.getExecutioner().execAndReturn(new SetRange(dWh2y, -5, 5)); + Nd4j.getExecutioner().execAndReturn(new SetRange(dbh, -5, 5)); + Nd4j.getExecutioner().execAndReturn(new SetRange(dby, -5, 5)); + + return loss; + } + + /** + * sample a sequence of integers from the model, using current (hPrev) memory state, seedIx is seed letter for first time step + */ + public String sample(int seedIx) { + + INDArray x = Nd4j.zeros(vocabSize, 1); + x.putScalar(seedIx, 1); + int sampleSize = 200; + INDArray ixes = Nd4j.create(sampleSize); + + for (int t = 0; t < sampleSize; t++) { + INDArray h = Transforms.tanh((wxh.mmul(x)).add((whh.mmul(hPrev)).add(bh))); + INDArray h2 = Transforms.tanh((whh.mmul(h)).add((whh2.mmul(hPrev2)).add(bh2))); + INDArray y = (wh2y.mmul(h2)).add(by); + INDArray exp = Transforms.exp(y); + INDArray pm = exp.div(Nd4j.sum(exp)).ravel(); + + List<Pair<Integer, Double>> d = new LinkedList<>(); + for (int pi = 0; pi < vocabSize; pi++) { + d.add(new Pair<>(pi, pm.getDouble(0, pi))); + } + EnumeratedDistribution<Integer> distribution = new EnumeratedDistribution<>(d); + + int ix = distribution.sample(); + + x = Nd4j.zeros(vocabSize, 1); + x.putScalar(ix, 1); + ixes.putScalar(t, ix); + } + + StringBuilder txt = new StringBuilder(); + + NdIndexIterator ndIndexIterator = new NdIndexIterator(ixes.shape()); + while (ndIndexIterator.hasNext()) { + int[] next = ndIndexIterator.next(); + txt.append(ixToChar.get(ixes.getInt(next))); + } + return txt.toString(); + } + +} Propchange: labs/yay/trunk/core/src/main/java/org/apache/yay/CharStackedRNN.java ------------------------------------------------------------------------------ svn:eol-style = native Modified: labs/yay/trunk/core/src/test/java/org/apache/yay/CharRNNCrossValidationTest.java URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/CharRNNCrossValidationTest.java?rev=1764448&r1=1764447&r2=1764448&view=diff ============================================================================== --- labs/yay/trunk/core/src/test/java/org/apache/yay/CharRNNCrossValidationTest.java (original) +++ labs/yay/trunk/core/src/test/java/org/apache/yay/CharRNNCrossValidationTest.java Wed Oct 12 12:36:05 2016 @@ -49,21 +49,53 @@ public class CharRNNCrossValidationTest @Parameterized.Parameters public static Collection<Object[]> data() { return Arrays.asList(new Object[][]{ - {1e-1f, 100, 25}, {1e-1f, 200, 512}, {1e-1f, 25, 25}, {1e-1f, 250, 512}, - {1e-1f, 25, 100}, {1e-1f, 200, 50}, {1e-1f, 200, 40}, {1e-1f, 100, 30}, {1e-1f, 100, 20}, {1e-1f, 250, 20}, {1e-1f, 250, 15}, - {1e-2f, 50, 64}, {3e-2f, 50, 128}, {1e-2f, 100, 128}, {1e-2f, 100, 256}, {1e-2f, 100, 512}, {1e-2f, 100, 128}, - {1e-3f, 100, 256}, {1e-3f, 100, 512}, {1e-4f, 100, 128}, {1e-4f, 100, 256}, - {1e-3f, 100, 100}, +// {1e-1f, 100, 25}, {1e-1f, 200, 512}, {5e-1f, 25, 25}, {1e-1f, 250, 512}, +// {5e-1f, 25, 100}, {5e-1f, 200, 50}, {5e-1f, 200, 40}, {5e-1f, 100, 30}, {5e-1f, 100, 20}, {5e-1f, 250, 20}, {5e-1f, 250, 15}, +// {5e-2f, 50, 64}, {3e-2f, 50, 128}, {5e-2f, 100, 128}, {5e-2f, 100, 256}, {5e-2f, 100, 512}, {5e-2f, 100, 128}, +// {5e-3f, 100, 256}, {5e-3f, 100, 512}, {5e-4f, 100, 128}, {5e-4f, 100, 256}, +// {5e-3f, 100, 100}, {5e-2f, 50, 100} + {4e-1f, 100, 10} }); } @Test - public void testLearnWithDifferentHyperparameters() throws Exception { - System.out.println("hyperparameters: " + learningRate + ", " + seqLength + ", " + hiddenLayerSize); - InputStream resourceAsStream = getClass().getResourceAsStream("/word2vec/abstracts.txt"); + public void testStackedCharRNNLearn() throws Exception { + InputStream resourceAsStream = getClass().getResourceAsStream("/word2vec/sentences.txt"); String text = IOUtils.toString(resourceAsStream); - int epochs = 1000000; + int epochs = 100; + CharRNN charRNN = new CharStackedRNN(learningRate, seqLength, hiddenLayerSize, epochs, text); + checkCorrectWordsRatio(text, charRNN); + } + + @Test + public void testStackedWordRNNLearn() throws Exception { + InputStream resourceAsStream = getClass().getResourceAsStream("/word2vec/sentences.txt"); + String text = IOUtils.toString(resourceAsStream); + int epochs = 100; + CharRNN charRNN = new CharStackedRNN(learningRate, seqLength, hiddenLayerSize, epochs, text, false); + checkCorrectWordsRatio(text, charRNN); + } + + @Test + public void testVanillaWordRNNLearn() throws Exception { + InputStream resourceAsStream = getClass().getResourceAsStream("/word2vec/sentences.txt"); + String text = IOUtils.toString(resourceAsStream); + int epochs = 100; + CharRNN charRNN = new CharRNN(learningRate, seqLength, hiddenLayerSize, epochs, text, false); + checkCorrectWordsRatio(text, charRNN); + } + + @Test + public void testVanillaCharRNNLearn() throws Exception { + InputStream resourceAsStream = getClass().getResourceAsStream("/word2vec/sentences.txt"); + String text = IOUtils.toString(resourceAsStream); + int epochs = 100; CharRNN charRNN = new CharRNN(learningRate, seqLength, hiddenLayerSize, epochs, text); + checkCorrectWordsRatio(text, charRNN); + } + + private void checkCorrectWordsRatio(String text, CharRNN charRNN) { + System.out.println(charRNN); List<String> words = Arrays.asList(text.split(" ")); charRNN.learn(); for (int i = 0; i < 10; i++) { @@ -79,6 +111,7 @@ public class CharRNNCrossValidationTest c /= sample.length(); } System.out.println("correct word ratio: " + c); + } } Modified: labs/yay/trunk/core/src/test/java/org/apache/yay/WordRNNCrossValidationTest.java URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/WordRNNCrossValidationTest.java?rev=1764448&r1=1764447&r2=1764448&view=diff ============================================================================== --- labs/yay/trunk/core/src/test/java/org/apache/yay/WordRNNCrossValidationTest.java (original) +++ labs/yay/trunk/core/src/test/java/org/apache/yay/WordRNNCrossValidationTest.java Wed Oct 12 12:36:05 2016 @@ -49,11 +49,12 @@ public class WordRNNCrossValidationTest @Parameterized.Parameters public static Collection<Object[]> data() { return Arrays.asList(new Object[][]{ - {1e-1f, 100, 25}, {1e-1f, 200, 512}, {1e-1f, 25, 25}, {1e-1f, 250, 512}, - {1e-1f, 25, 100}, {1e-1f, 200, 50}, {1e-1f, 200, 40}, {1e-1f, 100, 30}, {1e-1f, 100, 20}, {1e-1f, 250, 20}, {1e-1f, 250, 15}, - {1e-2f, 50, 64}, {3e-2f, 50, 128}, {1e-2f, 100, 128}, {1e-2f, 100, 256}, {1e-2f, 100, 512}, {1e-2f, 100, 128}, - {1e-3f, 100, 256}, {1e-3f, 100, 512}, {1e-4f, 100, 128}, {1e-4f, 100, 256}, - {2e-1f, 25, 100}, +// {1e-1f, 100, 25}, {1e-1f, 200, 512}, {1e-1f, 25, 25}, {1e-1f, 250, 512}, +// {1e-1f, 25, 100}, {1e-1f, 200, 50}, {1e-1f, 200, 40}, {1e-1f, 100, 30}, {1e-1f, 100, 20}, {1e-1f, 250, 20}, {1e-1f, 250, 15}, +// {1e-2f, 50, 64}, {3e-2f, 50, 128}, {1e-2f, 100, 128}, {1e-2f, 100, 256}, {1e-2f, 100, 512}, {1e-2f, 100, 128}, +// {1e-3f, 100, 256}, {1e-3f, 100, 512}, {1e-4f, 100, 128}, {1e-4f, 100, 256}, +// {2e-1f, 25, 100}, + {5e-2f, 50, 64} }); } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@labs.apache.org For additional commands, e-mail: commits-h...@labs.apache.org