Repository: opennlp-sandbox Updated Branches: refs/heads/master 6bfb15f07 -> fe2b1d920
fixed adagrad update for (s)rnn, added rmsprop to srnn Project: http://git-wip-us.apache.org/repos/asf/opennlp-sandbox/repo Commit: http://git-wip-us.apache.org/repos/asf/opennlp-sandbox/commit/fe2b1d92 Tree: http://git-wip-us.apache.org/repos/asf/opennlp-sandbox/tree/fe2b1d92 Diff: http://git-wip-us.apache.org/repos/asf/opennlp-sandbox/diff/fe2b1d92 Branch: refs/heads/master Commit: fe2b1d920512bfbf696863459a14e9df7533480c Parents: 6bfb15f Author: Tommaso Teofili <[email protected]> Authored: Sun May 28 08:56:55 2017 +0200 Committer: Tommaso Teofili <[email protected]> Committed: Sun May 28 08:56:55 2017 +0200 ---------------------------------------------------------------------- .../src/main/java/opennlp/tools/dl/RNN.java | 10 +-- .../main/java/opennlp/tools/dl/StackedRNN.java | 71 ++++++++++++++------ .../src/test/java/opennlp/tools/dl/RNNTest.java | 1 + .../java/opennlp/tools/dl/StackedRNNTest.java | 5 +- 4 files changed, 60 insertions(+), 27 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/opennlp-sandbox/blob/fe2b1d92/opennlp-dl/src/main/java/opennlp/tools/dl/RNN.java ---------------------------------------------------------------------- diff --git a/opennlp-dl/src/main/java/opennlp/tools/dl/RNN.java b/opennlp-dl/src/main/java/opennlp/tools/dl/RNN.java index 2fabecd..417b98c 100644 --- a/opennlp-dl/src/main/java/opennlp/tools/dl/RNN.java +++ b/opennlp-dl/src/main/java/opennlp/tools/dl/RNN.java @@ -175,19 +175,19 @@ public class RNN { // perform parameter update with Adagrad mWxh.addi(dWxh.mul(dWxh)); - wxh.subi((dWxh.mul(learningRate)).div(Transforms.sqrt(mWxh.add(reg)))); + wxh.subi((dWxh.mul(learningRate)).div(Transforms.sqrt(mWxh).add(reg))); mWhh.addi(dWhh.mul(dWhh)); - whh.subi(dWhh.mul(learningRate).div(Transforms.sqrt(mWhh.add(reg)))); + whh.subi(dWhh.mul(learningRate).div(Transforms.sqrt(mWhh).add(reg))); mWhy.addi(dWhy.mul(dWhy)); - why.subi(dWhy.mul(learningRate).div(Transforms.sqrt(mWhy.add(reg)))); + why.subi(dWhy.mul(learningRate).div(Transforms.sqrt(mWhy).add(reg))); mbh.addi(dbh.mul(dbh)); - bh.subi(dbh.mul(learningRate).div(Transforms.sqrt(mbh.add(reg)))); + bh.subi(dbh.mul(learningRate).div(Transforms.sqrt(mbh).add(reg))); mby.addi(dby.mul(dby)); - by.subi(dby.mul(learningRate).div(Transforms.sqrt(mby.add(reg)))); + by.subi(dby.mul(learningRate).div(Transforms.sqrt(mby).add(reg))); } p += seqLength; // move data pointer http://git-wip-us.apache.org/repos/asf/opennlp-sandbox/blob/fe2b1d92/opennlp-dl/src/main/java/opennlp/tools/dl/StackedRNN.java ---------------------------------------------------------------------- diff --git a/opennlp-dl/src/main/java/opennlp/tools/dl/StackedRNN.java b/opennlp-dl/src/main/java/opennlp/tools/dl/StackedRNN.java index e6ceb9b..889fac1 100644 --- a/opennlp-dl/src/main/java/opennlp/tools/dl/StackedRNN.java +++ b/opennlp-dl/src/main/java/opennlp/tools/dl/StackedRNN.java @@ -54,18 +54,21 @@ public class StackedRNN extends RNN { private final INDArray bh2; // hidden2 bias private final INDArray by; // output bias - private final double reg = 1e-8; + private final double eps = 1e-4; + private final double decay = 0.9; + private final boolean rmsProp; private INDArray hPrev = null; // memory state private INDArray hPrev2 = null; // memory state public StackedRNN(float learningRate, int seqLength, int hiddenLayerSize, int epochs, String text) { - this(learningRate, seqLength, hiddenLayerSize, epochs, text, 1, true); + this(learningRate, seqLength, hiddenLayerSize, epochs, text, 1, true, false); } - public StackedRNN(float learningRate, int seqLength, int hiddenLayerSize, int epochs, String text, int batch, boolean useChars) { + public StackedRNN(float learningRate, int seqLength, int hiddenLayerSize, int epochs, String text, int batch, boolean useChars, boolean rmsProp) { super(learningRate, seqLength, hiddenLayerSize, epochs, text, batch, useChars); + this.rmsProp = rmsProp; wxh = Nd4j.randn(hiddenLayerSize, vocabSize).div(Math.sqrt(hiddenLayerSize)); whh = Nd4j.randn(hiddenLayerSize, hiddenLayerSize).div(Math.sqrt(hiddenLayerSize)); whh2 = Nd4j.randn(hiddenLayerSize, hiddenLayerSize).div(Math.sqrt(hiddenLayerSize)); @@ -141,30 +144,58 @@ public class StackedRNN extends RNN { } if (n % batch == 0) { - // perform parameter update with Adagrad - mWxh.addi(dWxh.mul(dWxh)); - wxh.subi(dWxh.mul(learningRate).div(Transforms.sqrt(mWxh.add(reg)))); + if (rmsProp) { + // perform parameter update with RMSprop + mWxh = mWxh.mul(decay).add(1 - decay).mul((dWxh).mul(dWxh)); + wxh.subi(dWxh.mul(learningRate).div(Transforms.sqrt(mWxh).add(eps))); - mWxh2.addi(dWxh2.mul(dWxh2)); - wxh2.subi(dWxh2.mul(learningRate).div(Transforms.sqrt(mWxh2.add(reg)))); + mWxh2 = mWxh2.mul(decay).add(1 - decay).mul((dWxh2).mul(dWxh2)); + wxh2.subi(dWxh2.mul(learningRate).div(Transforms.sqrt(mWxh2).add(eps))); - mWhh.addi(dWhh.mul(dWhh)); - whh.subi(dWhh.mul(learningRate).div(Transforms.sqrt(mWhh.add(reg)))); + mWhh = mWhh.mul(decay).add(1 - decay).mul((dWhh).mul(dWhh)); + whh.subi(dWhh.mul(learningRate).div(Transforms.sqrt(mWhh).add(eps))); - mWhh2.addi(dWhh2.mul(dWhh2)); - whh2.subi(dWhh2.mul(learningRate).div(Transforms.sqrt(mWhh2.add(reg)))); + mWhh2 = mWhh2.mul(decay).add(1 - decay).mul((dWhh2).mul(dWhh2)); + whh2.subi(dWhh2.mul(learningRate).div(Transforms.sqrt(mWhh2).add(eps))); - mbh2.addi(dbh2.mul(dbh2)); - bh2.subi(dbh2.mul(learningRate).div(Transforms.sqrt(mbh2.add(reg)))); + mbh2 = mbh2.mul(decay).add(1 - decay).mul((dbh2).mul(dbh2)); + bh2.subi(dbh2.mul(learningRate).div(Transforms.sqrt(mbh2).add(eps))); - mWh2y.addi(dWh2y.mul(dWh2y)); - wh2y.subi(dWh2y.mul(learningRate).div(Transforms.sqrt(mWh2y.add(reg)))); + mWh2y = mWh2y.mul(decay).add(1 - decay).mul((dWh2y).mul(dWh2y)); + wh2y.subi(dWh2y.mul(learningRate).div(Transforms.sqrt(mWh2y).add(eps))); - mbh.addi(dbh.mul(dbh)); - bh.subi(dbh.mul(learningRate).div(Transforms.sqrt(mbh.add(reg)))); + mbh = mbh.mul(decay).add(1 - decay).mul((dbh).mul(dbh)); + bh.subi(dbh.mul(learningRate).div(Transforms.sqrt(mbh).add(eps))); - mby.addi(dby.mul(dby)); - by.subi(dby.mul(learningRate).div(Transforms.sqrt(mby.add(reg)))); + mby = mby.mul(decay).add(1 - decay).mul((dby).mul(dby)); + by.subi(dby.mul(learningRate).div(Transforms.sqrt(mby).add(eps))); + } else { + // perform parameter update with Adagrad + + mWxh.addi(dWxh.mul(dWxh)); + wxh.subi(dWxh.mul(learningRate).div(Transforms.sqrt(mWxh).add(eps))); + + mWxh2.addi(dWxh2.mul(dWxh2)); + wxh2.subi(dWxh2.mul(learningRate).div(Transforms.sqrt(mWxh2).add(eps))); + + mWhh.addi(dWhh.mul(dWhh)); + whh.subi(dWhh.mul(learningRate).div(Transforms.sqrt(mWhh).add(eps))); + + mWhh2.addi(dWhh2.mul(dWhh2)); + whh2.subi(dWhh2.mul(learningRate).div(Transforms.sqrt(mWhh2).add(eps))); + + mbh2.addi(dbh2.mul(dbh2)); + bh2.subi(dbh2.mul(learningRate).div(Transforms.sqrt(mbh2).add(eps))); + + mWh2y.addi(dWh2y.mul(dWh2y)); + wh2y.subi(dWh2y.mul(learningRate).div(Transforms.sqrt(mWh2y).add(eps))); + + mbh.addi(dbh.mul(dbh)); + bh.subi(dbh.mul(learningRate).div(Transforms.sqrt(mbh).add(eps))); + + mby.addi(dby.mul(dby)); + by.subi(dby.mul(learningRate).div(Transforms.sqrt(mby).add(eps))); + } } p += seqLength; // move data pointer http://git-wip-us.apache.org/repos/asf/opennlp-sandbox/blob/fe2b1d92/opennlp-dl/src/test/java/opennlp/tools/dl/RNNTest.java ---------------------------------------------------------------------- diff --git a/opennlp-dl/src/test/java/opennlp/tools/dl/RNNTest.java b/opennlp-dl/src/test/java/opennlp/tools/dl/RNNTest.java index 57f7682..88a9413 100644 --- a/opennlp-dl/src/test/java/opennlp/tools/dl/RNNTest.java +++ b/opennlp-dl/src/test/java/opennlp/tools/dl/RNNTest.java @@ -18,6 +18,7 @@ */ package opennlp.tools.dl; +import java.io.FileInputStream; import java.io.InputStream; import java.util.Arrays; import java.util.Collection; http://git-wip-us.apache.org/repos/asf/opennlp-sandbox/blob/fe2b1d92/opennlp-dl/src/test/java/opennlp/tools/dl/StackedRNNTest.java ---------------------------------------------------------------------- diff --git a/opennlp-dl/src/test/java/opennlp/tools/dl/StackedRNNTest.java b/opennlp-dl/src/test/java/opennlp/tools/dl/StackedRNNTest.java index 686d603..265426f 100644 --- a/opennlp-dl/src/test/java/opennlp/tools/dl/StackedRNNTest.java +++ b/opennlp-dl/src/test/java/opennlp/tools/dl/StackedRNNTest.java @@ -18,6 +18,7 @@ */ package opennlp.tools.dl; +import java.io.FileInputStream; import java.io.InputStream; import java.util.Arrays; import java.util.Collection; @@ -69,14 +70,14 @@ public class StackedRNNTest { @Test public void testStackedCharRNNLearn() throws Exception { - RNN rnn = new StackedRNN(learningRate, seqLength, hiddenLayerSize, epochs, text, 5, true); + RNN rnn = new StackedRNN(learningRate, seqLength, hiddenLayerSize, epochs, text, 5, true, true); evaluate(rnn, true); rnn.serialize("target/scrnn-weights-"); } @Test public void testStackedWordRNNLearn() throws Exception { - RNN rnn = new StackedRNN(learningRate, seqLength, hiddenLayerSize, epochs, text, 1, false); + RNN rnn = new StackedRNN(learningRate, seqLength, hiddenLayerSize, epochs, text, 1, false, false); evaluate(rnn, true); rnn.serialize("target/swrnn-weights-"); }
