Author: tommaso
Date: Fri Oct 14 11:57:34 2016
New Revision: 1764880

URL: http://svn.apache.org/viewvc?rev=1764880&view=rev
Log:
fixed wrong adagrad update, more cli params for nn runner
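[Editor's note: the "more cli params" part of this change reads the learning rate from args[6] and moves the optional "stacked" flag to args[7]. The example invocation in the code comment below still shows the old argument order; with the new order it would presumably become something like (the 1e-1 value is illustrative, matching the new default):

    bin/nn recurrent core/src/test/resources/word2vec/sentences.txt true 100 25 100 1e-1 stacked
]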
Modified:
    labs/yay/trunk/core/src/main/java/org/apache/yay/NNRunner.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/RNN.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/StackedRNN.java

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/NNRunner.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/NNRunner.java?rev=1764880&r1=1764879&r2=1764880&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/NNRunner.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/NNRunner.java Fri Oct 14 11:57:34 2016
@@ -41,7 +41,7 @@ public class NNRunner {
       case "recurrent": {
         // recurrent neural network
         // e.g. bin/nn recurrent core/src/test/resources/word2vec/sentences.txt true 100 25 100 stacked
-        float learningRate = 1e-2f;
+        float learningRate = 1e-1f;
         int seqLength = 25;
         int hiddenLayerSize = 30;
         int epochs = 20;
@@ -69,9 +69,13 @@ public class NNRunner {
         if (args.length > 5 && args[5] != null) {
           seqLength = Integer.valueOf(args[5]);
         }
-        RNN rnn = new RNN(learningRate, seqLength, hiddenLayerSize, epochs, text, useChars);
         if (args.length > 6 && args[6] != null) {
-          if ("stacked".equals(args[6])) {
+          learningRate = Float.valueOf(args[6]);
+        }
+
+        RNN rnn = new RNN(learningRate, seqLength, hiddenLayerSize, epochs, text, useChars);
+        if (args.length > 7 && args[7] != null) {
+          if ("stacked".equals(args[7])) {
             rnn = new StackedRNN(learningRate, seqLength, hiddenLayerSize, epochs, text, useChars);
           }
         }
@@ -79,6 +83,11 @@ public class NNRunner {
         rnn.learn();
         int seed = random.nextInt(rnn.vocabSize);
         System.out.println(rnn.sample(seed));
+        try {
+          rnn.serialize("weights.txt");
+        } catch (IOException e) {
+          throw new RuntimeException("cannot serialize weights", e);
+        }
         break;
       }
       case "multi": {

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/RNN.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/RNN.java?rev=1764880&r1=1764879&r2=1764880&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/RNN.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/RNN.java Fri Oct 14 11:57:34 2016
@@ -170,19 +170,19 @@ public class RNN {
 
       // perform parameter update with Adagrad
       mWxh.addi(dWxh.mul(dWxh));
-      wxh.subi((dWxh.mul(learningRate)).div(Transforms.sqrt(mWxh.addi(1e-8))));
+      wxh.subi((dWxh.mul(learningRate)).div(Transforms.sqrt(mWxh.add(1e-8))));
 
       mWhh.addi(dWhh.mul(dWhh));
-      whh.subi(dWhh.mul(learningRate).div(Transforms.sqrt(mWhh.addi(1e-8))));
+      whh.subi(dWhh.mul(learningRate).div(Transforms.sqrt(mWhh.add(1e-8))));
 
       mWhy.addi(dWhy.mul(dWhy));
-      why.subi(dWhy.mul(learningRate).div(Transforms.sqrt(mWhy.addi(1e-8))));
+      why.subi(dWhy.mul(learningRate).div(Transforms.sqrt(mWhy.add(1e-8))));
 
       mbh.addi(dbh.mul(dbh));
-      bh.subi(dbh.mul(learningRate).div(Transforms.sqrt(mbh.addi(1e-8))));
+      bh.subi(dbh.mul(learningRate).div(Transforms.sqrt(mbh.add(1e-8))));
 
       mby.addi(dby.mul(dby));
-      by.subi(dby.mul(learningRate).div(Transforms.sqrt(mby.addi(1e-8))));
+      by.subi(dby.mul(learningRate).div(Transforms.sqrt(mby.add(1e-8))));
 
       p += seqLength; // move data pointer
       n++; // iteration counter
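[Editor's note: the addi-to-add change above is the "wrong adagrad update" named in the log. ND4J's addi(1e-8) adds the smoothing epsilon to the gradient cache in place, so the cache drifted upward by 1e-8 per element on every iteration; add(1e-8) operates on a copy and leaves the cache intact. The in-place accumulation of squared gradients (mWxh.addi(dWxh.mul(dWxh))) is intended and unchanged. A minimal sketch of the corrected step, assuming ND4J's INDArray API; the names w, dw and cache are illustrative, not from the patch:

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;
    import org.nd4j.linalg.ops.transforms.Transforms;

    public class AdagradSketch {

      // One Adagrad step: the cache accumulates squared gradients and the
      // update is scaled by the inverse square root of the cache.
      static void step(INDArray w, INDArray dw, INDArray cache, float learningRate) {
        cache.addi(dw.mul(dw)); // in-place accumulation is intended here
        // add(1e-8) returns a copy, so the epsilon never leaks into the cache;
        // the old addi(1e-8) silently grew the cache by 1e-8 on every step.
        w.subi(dw.mul(learningRate).div(Transforms.sqrt(cache.add(1e-8))));
      }

      public static void main(String[] args) {
        INDArray w = Nd4j.rand(3, 3);
        INDArray dw = Nd4j.rand(3, 3);
        INDArray cache = Nd4j.zeros(3, 3);
        step(w, dw, cache, 1e-1f);
        System.out.println(w);
      }
    }
]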
Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/StackedRNN.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/StackedRNN.java?rev=1764880&r1=1764879&r2=1764880&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/StackedRNN.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/StackedRNN.java Fri Oct 14 11:57:34 2016
@@ -22,6 +22,7 @@ import org.apache.commons.math3.distribu
 import org.apache.commons.math3.util.Pair;
 import org.nd4j.linalg.api.ndarray.INDArray;
 import org.nd4j.linalg.api.ops.impl.transforms.SetRange;
+import org.nd4j.linalg.api.ops.impl.transforms.SoftMax;
 import org.nd4j.linalg.factory.Nd4j;
 import org.nd4j.linalg.ops.transforms.Transforms;
 
@@ -74,7 +75,7 @@ public class StackedRNN extends RNN {
 
   public void learn() {
 
-    int currentEpoch = 0;
+    int currentEpoch = -1;
 
     int n = 0;
     int p = 0;
@@ -136,25 +137,25 @@ public class StackedRNN extends RNN {
 
       // perform parameter update with Adagrad
       mWxh.addi(dWxh.mul(dWxh));
-      wxh.subi((dWxh.mul(learningRate)).div(Transforms.sqrt(mWxh.addi(1e-8))));
+      wxh.subi((dWxh.mul(learningRate)).div(Transforms.sqrt(mWxh.add(1e-8))));
 
       mWhh.addi(dWhh.mul(dWhh));
-      whh.subi(dWhh.mul(learningRate).div(Transforms.sqrt(mWhh.addi(1e-8))));
+      whh.subi(dWhh.mul(learningRate).div(Transforms.sqrt(mWhh.add(1e-8))));
 
       mWhh2.addi(dWhh2.mul(dWhh2));
-      whh2.subi(dWhh2.mul(learningRate).div(Transforms.sqrt(mWhh2.addi(1e-8))));
+      whh2.subi(dWhh2.mul(learningRate).div(Transforms.sqrt(mWhh2.add(1e-8))));
 
       mbh2.addi(dbh2.mul(dbh2));
-      bh2.subi(dbh2.mul(learningRate).div(Transforms.sqrt(mbh2.addi(1e-8))));
+      bh2.subi(dbh2.mul(learningRate).div(Transforms.sqrt(mbh2.add(1e-8))));
 
       mWh2y.addi(dWh2y.mul(dWh2y));
-      wh2y.subi(dWh2y.mul(learningRate).div(Transforms.sqrt(mWh2y.addi(1e-8))));
+      wh2y.subi(dWh2y.mul(learningRate).div(Transforms.sqrt(mWh2y.add(1e-8))));
 
       mbh.addi(dbh.mul(dbh));
-      bh.subi(dbh.mul(learningRate).div(Transforms.sqrt(mbh.addi(1e-8))));
+      bh.subi(dbh.mul(learningRate).div(Transforms.sqrt(mbh.add(1e-8))));
 
       mby.addi(dby.mul(dby));
-      by.subi(dby.mul(learningRate).div(Transforms.sqrt(mby.addi(1e-8))));
+      by.subi(dby.mul(learningRate).div(Transforms.sqrt(mby.add(1e-8))));
 
       p += seqLength; // move data pointer
       n++; // iteration counter
@@ -190,6 +191,7 @@ public class StackedRNN extends RNN {
       INDArray hsRow = t == 0 ? hs1 : hs.getRow(t - 1);
       INDArray hst = Transforms.tanh((wxh.mmul(xs.getRow(t).transpose())).add((whh.mmul(hsRow)).add(bh))); // hidden state
+//      INDArray hst = Transforms.relu((wxh.mmul(xs.getRow(t).transpose())).add((whh.mmul(hsRow)).add(bh))); // hidden state
       if (hs == null) {
         hs = init(inputs.length(), hst);
       }
 
@@ -197,6 +199,7 @@
       INDArray hs2Row = t == 0 ? hs12 : hs2.getRow(t - 1);
       INDArray hst2 = Transforms.tanh((whh.mmul(hs.getRow(t))).add((whh2.mmul(hs2Row)).add(bh2))); // hidden state 2
+//      INDArray hst2 = Transforms.relu((whh.mmul(hs.getRow(t))).add((whh2.mmul(hs2Row)).add(bh2))); // hidden state 2
       if (hs2 == null) {
         hs2 = init(inputs.length(), hst2);
       }
 
@@ -207,9 +210,11 @@
         ys = init(inputs.length(), yst);
       }
       ys.putRow(t, yst);
-      INDArray exp = Transforms.exp(yst);
-      Number sumExp = exp.sumNumber();
-      INDArray pst = exp.div(sumExp); // probabilities for next chars
+
+//      INDArray exp = Transforms.exp(yst);
+//      Number sumExp = exp.sumNumber();
+//      INDArray pst = exp.div(sumExp); // probabilities for next chars
+      INDArray pst = Nd4j.getExecutioner().execAndReturn(new SoftMax(yst));
       if (ps == null) {
         ps = init(inputs.length(), pst);
       }
@@ -272,7 +277,9 @@ public class StackedRNN extends RNN {
 
     for (int t = 0; t < sampleSize; t++) {
       INDArray h = Transforms.tanh((wxh.mmul(x)).add((whh.mmul(hPrev)).add(bh)));
+//      INDArray h = Transforms.relu((wxh.mmul(x)).add((whh.mmul(hPrev)).add(bh)));
       INDArray h2 = Transforms.tanh((whh.mmul(h)).add((whh2.mmul(hPrev2)).add(bh2)));
+//      INDArray h2 = Transforms.relu((whh.mmul(h)).add((whh2.mmul(hPrev2)).add(bh2)));
       INDArray y = (wh2y.mmul(h2)).add(by);
       INDArray exp = Transforms.exp(y);
       INDArray pm = exp.div(Nd4j.sum(exp)).ravel();
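[Editor's note: the other notable change in StackedRNN is replacing the hand-rolled exp/sum normalization in the forward pass with ND4J's SoftMax op. A common reason for such a swap is numerical stability: exp(yst) overflows to infinity for large activations, while the usual max-shifted formulation produces the same distribution with small exponents. An illustrative sketch of the two forms, assuming ND4J's API and not the op's actual internals:

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;
    import org.nd4j.linalg.ops.transforms.Transforms;

    public class SoftmaxSketch {

      // Naive form, as in the replaced code: exp() overflows for large logits.
      static INDArray naiveSoftmax(INDArray logits) {
        INDArray exp = Transforms.exp(logits);
        return exp.div(exp.sumNumber());
      }

      // Max-shifted form: mathematically identical, but exp() never sees
      // large inputs, so the probabilities stay finite.
      static INDArray stableSoftmax(INDArray logits) {
        INDArray exp = Transforms.exp(logits.sub(logits.maxNumber()));
        return exp.div(exp.sumNumber());
      }

      public static void main(String[] args) {
        INDArray logits = Nd4j.create(new float[] {1000f, 1001f, 1002f});
        System.out.println(naiveSoftmax(logits));  // NaNs: exp(1000) overflows
        System.out.println(stableSoftmax(logits)); // finite probabilities
      }
    }

Note that the sample() path in the last hunk above still uses the naive exp/sum form.]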