Author: tommaso
Date: Fri Mar 3 10:27:37 2017
New Revision: 1785257

URL: http://svn.apache.org/viewvc?rev=1785257&view=rev
Log:
minor tweaks
Modified:
    labs/yay/trunk/core/pom.xml
    labs/yay/trunk/core/src/main/java/org/apache/yay/RNN.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/StackedRNN.java

Modified: labs/yay/trunk/core/pom.xml
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/pom.xml?rev=1785257&r1=1785256&r2=1785257&view=diff
==============================================================================
--- labs/yay/trunk/core/pom.xml (original)
+++ labs/yay/trunk/core/pom.xml Fri Mar 3 10:27:37 2017
@@ -26,7 +26,7 @@
     <relativePath>../</relativePath>
   </parent>
   <properties>
-    <nd4j.version>0.6.0</nd4j.version>
+    <dl4j.version>0.7.2</dl4j.version>
   </properties>
   <name>Yay core</name>
   <dependencies>
@@ -58,7 +58,12 @@
     <dependency>
       <groupId>org.nd4j</groupId>
       <artifactId>nd4j-native-platform</artifactId>
-      <version>${nd4j.version}</version>
+      <version>${dl4j.version}</version>
+    </dependency>
+    <dependency>
+      <groupId>org.deeplearning4j</groupId>
+      <artifactId>deeplearning4j-nlp</artifactId>
+      <version>${dl4j.version}</version>
     </dependency>
   </dependencies>

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/RNN.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/RNN.java?rev=1785257&r1=1785256&r2=1785257&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/RNN.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/RNN.java Fri Mar 3 10:27:37 2017
@@ -226,25 +226,23 @@ public class RNN {
       INDArray hsRow = t == 0 ? hs1 : hs.getRow(t - 1);
       INDArray hst = Transforms.tanh((wxh.mmul(xs.getRow(t).transpose())).add((whh.mmul(hsRow)).add(bh))); // hidden state
       if (hs == null) {
-        hs = init(inputs.length(), hst);
+        hs = init(inputs.length(), hst.shape());
       }
       hs.putRow(t, hst);
 
       INDArray yst = (why.mmul(hst)).add(by); // unnormalized log probabilities for next chars
       if (ys == null) {
-        ys = init(inputs.length(), yst);
+        ys = init(inputs.length(), yst.shape());
       }
       ys.putRow(t, yst);
       INDArray pst = Nd4j.getExecutioner().execAndReturn(new SoftMax(yst)); // probabilities for next chars
       if (ps == null) {
-        ps = init(inputs.length(), pst);
+        ps = init(inputs.length(), pst.shape());
       }
       ps.putRow(t, pst);
       loss += -Math.log(pst.getDouble(targets.getInt(t))); // softmax (cross-entropy loss)
     }
 
-    this.hPrev = hs.getRow(inputs.length() - 1);
-
     // backward pass: compute gradients going backwards
     INDArray dhNext = Nd4j.zerosLike(hs.getRow(0));
     for (int t = inputs.length() - 1; t >= 0; t--) {
@@ -269,12 +267,13 @@ public class RNN {
     Nd4j.getExecutioner().exec(new SetRange(dbh, -5, 5));
     Nd4j.getExecutioner().exec(new SetRange(dby, -5, 5));
 
+    this.hPrev = hs.getRow(inputs.length() - 1);
+
     return loss;
   }
 
-  protected INDArray init(int t, INDArray ast) {
+  protected INDArray init(int t, int[] aShape) {
     INDArray as;
-    int[] aShape = ast.shape();
     int[] shape = new int[1 + aShape.length];
     shape[0] = t;
     System.arraycopy(aShape, 0, shape, 1, aShape.length);
@@ -292,9 +291,11 @@ public class RNN {
     int sampleSize = 2 * seqLength;
     INDArray ixes = Nd4j.create(sampleSize);
 
+    INDArray h = hPrev.dup();
+
     for (int t = 0; t < sampleSize; t++) {
-      hPrev = Transforms.tanh((wxh.mmul(x)).add((whh.mmul(hPrev)).add(bh)));
-      INDArray y = (why.mmul(hPrev)).add(by);
+      h = Transforms.tanh((wxh.mmul(x)).add((whh.mmul(h)).add(bh)));
+      INDArray y = (why.mmul(h)).add(by);
       INDArray pm = Nd4j.getExecutioner().execAndReturn(new SoftMax(y)).ravel();
 
       List<Pair<Integer, Double>> d = new LinkedList<>();
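
The hunk above only shows the head of the refactored init(int, int[]) helper, so the following is a minimal sketch of what it is assumed to do: allocate a zero-filled container with a leading time dimension, whose rows are then filled per timestep via putRow. The zero-filled allocation at the end is an assumption, not part of the visible diff.

    protected INDArray init(int t, int[] aShape) {
      // prepend the sequence length t to the per-timestep shape
      int[] shape = new int[1 + aShape.length];
      shape[0] = t;
      System.arraycopy(aShape, 0, shape, 1, aShape.length);
      // assumed: zero-filled container, one row per timestep (hidden states, logits or probabilities)
      return Nd4j.zeros(shape);
    }
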
Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/StackedRNN.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/StackedRNN.java?rev=1785257&r1=1785256&r2=1785257&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/StackedRNN.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/StackedRNN.java Fri Mar 3 10:27:37 2017
@@ -66,10 +66,10 @@ public class StackedRNN extends RNN {
   public StackedRNN(float learningRate, int seqLength, int hiddenLayerSize, int epochs, String text, boolean useChars) {
     super(learningRate, seqLength, hiddenLayerSize, epochs, text, useChars);
 
-    wxh = Nd4j.randn(hiddenLayerSize, vocabSize).mul(0.01);
-    whh = Nd4j.randn(hiddenLayerSize, hiddenLayerSize).mul(0.01);
-    whh2 = Nd4j.randn(hiddenLayerSize, hiddenLayerSize).mul(0.01);
-    wh2y = Nd4j.randn(vocabSize, hiddenLayerSize).mul(0.01);
+    wxh = Nd4j.randn(hiddenLayerSize, vocabSize).div(Math.sqrt(hiddenLayerSize));
+    whh = Nd4j.randn(hiddenLayerSize, hiddenLayerSize).div(Math.sqrt(hiddenLayerSize));
+    whh2 = Nd4j.randn(hiddenLayerSize, hiddenLayerSize).div(Math.sqrt(hiddenLayerSize));
+    wh2y = Nd4j.randn(vocabSize, hiddenLayerSize).div(Math.sqrt(vocabSize));
     bh = Nd4j.zeros(hiddenLayerSize, 1);
     bh2 = Nd4j.zeros(hiddenLayerSize, 1);
     by = Nd4j.zeros(vocabSize, 1);
@@ -129,8 +129,8 @@ public class StackedRNN extends RNN {
         // forward seqLength characters through the net and fetch gradient
         double loss = lossFun(inputs, targets, dWxh, dWhh, dWhh2, dWh2y, dbh, dbh2, dby);
         smoothLoss = smoothLoss * 0.999 + loss * 0.001;
-        if (Double.isNaN(smoothLoss)) {
-          System.out.println("loss is NaN (over/underflow occured, try adjusting hyperparameters)");
+        if (Double.isNaN(smoothLoss) || Double.isInfinite(smoothLoss)) {
+          System.out.println("loss is " + smoothLoss + " (over/underflow occured, try adjusting hyperparameters)");
           break;
         }
         if (n % 100 == 0) {
@@ -139,7 +139,7 @@ public class StackedRNN extends RNN {
 
         // perform parameter update with Adagrad
         mWxh.addi(dWxh.mul(dWxh));
-        wxh.subi((dWxh.mul(learningRate)).div(Transforms.sqrt(mWxh.add(reg))));
+        wxh.subi(dWxh.mul(learningRate).div(Transforms.sqrt(mWxh.add(reg))));
 
         mWhh.addi(dWhh.mul(dWhh));
         whh.subi(dWhh.mul(learningRate).div(Transforms.sqrt(mWhh.add(reg))));
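
The Adagrad update shown above is presumably applied in the same form to the remaining parameters (whh2, wh2y, bh, bh2, by), which are outside this hunk. As a generic sketch, with param, grad and cache standing in for any weight, gradient and accumulated-squared-gradient triple:

      cache.addi(grad.mul(grad));                                               // cache += grad * grad (element-wise)
      param.subi(grad.mul(learningRate).div(Transforms.sqrt(cache.add(reg)))); // param -= lr * grad / sqrt(cache + reg)
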
@@ -185,30 +185,29 @@ public class StackedRNN extends RNN {
       int tIndex = inputs.getScalar(t).getInt(0);
       xs.putScalar(t, tIndex, 1); // encode in 1-of-k representation
 
-      INDArray hsRow = t == 0 ? hPrev : hs.getRow(t - 1);
       INDArray xst = xs.getRow(t);
-      INDArray hst = Transforms.tanh((wxh.mmul(xst.transpose())).add((whh.mmul(hsRow)).add(bh))); // hidden state
+
+      hPrev = Transforms.tanh((wxh.mmul(xst.transpose()).add(whh.mmul(hPrev)).add(bh))); // hidden state
       if (hs == null) {
-        hs = init(seqLength, hst);
+        hs = init(seqLength, hPrev.shape());
       }
-      hs.putRow(t, hst);
+      hs.putRow(t, hPrev.dup());
 
-      INDArray hs2Row = t == 0 ? hPrev2 : hs2.getRow(t - 1);
-      INDArray hst2 = Transforms.tanh((whh.mmul(hst)).add((whh2.mmul(hs2Row)).add(bh2))); // hidden state 2
+      hPrev2 = Transforms.tanh((whh.mmul(hs.getRow(t)).add(whh2.mmul(hPrev2)).add(bh2))); // hidden state 2
       if (hs2 == null) {
-        hs2 = init(seqLength, hst2);
+        hs2 = init(seqLength, hPrev2.shape());
       }
-      hs2.putRow(t, hst2);
+      hs2.putRow(t, hPrev2.dup());
 
-      INDArray yst = (wh2y.mmul(hst2)).add(by); // unnormalized log probabilities for next chars
+      INDArray yst = wh2y.mmul(hs2.getRow(t)).add(by); // unnormalized log probabilities for next chars
       if (ys == null) {
-        ys = init(seqLength, yst);
+        ys = init(seqLength, yst.shape());
       }
       ys.putRow(t, yst);
 
       INDArray pst = Nd4j.getExecutioner().execAndReturn(new SoftMax(yst)); // probabilities for next chars
       if (ps == null) {
-        ps = init(seqLength, pst);
+        ps = init(seqLength, pst.shape());
       }
       ps.putRow(t, pst);
@@ -258,7 +257,6 @@ public class StackedRNN extends RNN {
     Nd4j.getExecutioner().exec(new SetRange(dbh, -clip, clip));
     Nd4j.getExecutioner().exec(new SetRange(dbh2, -clip, clip));
     Nd4j.getExecutioner().exec(new SetRange(dby, -clip, clip));
-
     return loss;
   }
 
@@ -273,26 +271,29 @@ public class StackedRNN extends RNN {
     int sampleSize = seqLength * 2;
     INDArray ixes = Nd4j.create(sampleSize);
 
-    INDArray h = hPrev;
-    INDArray h2 = hPrev2;
+    INDArray h = hPrev.dup();
+    INDArray h2 = hPrev2.dup();
 
     for (int t = 0; t < sampleSize; t++) {
-      h = Transforms.tanh((wxh.mmul(x)).add((whh.mmul(h)).add(bh)));
-      h2 = Transforms.tanh((whh.mmul(h)).add((whh2.mmul(h2)).add(bh2)));
-      INDArray y = (wh2y.mmul(h2)).add(by);
+      h = Transforms.tanh(((wxh.mmul(x)).add((whh.mmul(h)).add(bh))));
+      h2 = Transforms.tanh(((whh.mmul(h)).add((whh2.mmul(h2)).add(bh2))));
+      INDArray y = wh2y.mmul(h2).add(by);
       INDArray pm = Nd4j.getExecutioner().execAndReturn(new SoftMax(y)).ravel();
 
       List<Pair<Integer, Double>> d = new LinkedList<>();
       for (int pi = 0; pi < vocabSize; pi++) {
        d.add(new Pair<>(pi, pm.getDouble(0, pi)));
       }
 
-      EnumeratedDistribution<Integer> distribution = new EnumeratedDistribution<>(d);
+      try {
+        EnumeratedDistribution<Integer> distribution = new EnumeratedDistribution<>(d);
 
-      int ix = distribution.sample();
+        int ix = distribution.sample();
 
-      x = Nd4j.zeros(vocabSize, 1);
-      x.putScalar(ix, 1);
-      ixes.putScalar(t, ix);
+        x = Nd4j.zeros(vocabSize, 1);
+        x.putScalar(ix, 1);
+        ixes.putScalar(t, ix);
+      } catch (Exception e) {
+      }
     }
 
     return getSampleString(ixes);
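
For reference, the sampling loop above draws the next character index from the softmax output using Commons Math's EnumeratedDistribution; a minimal sketch of the pieces it relies on (the imports are not visible in this diff and are assumed):

    import java.util.LinkedList;
    import java.util.List;
    import org.apache.commons.math3.distribution.EnumeratedDistribution;
    import org.apache.commons.math3.util.Pair;

    // categorical distribution over the vocabulary, weighted by the softmax row pm
    List<Pair<Integer, Double>> d = new LinkedList<>();
    for (int pi = 0; pi < vocabSize; pi++) {
      d.add(new Pair<>(pi, pm.getDouble(0, pi)));
    }
    // the constructor rejects NaN, negative or all-zero weights,
    // which is presumably what the new try/catch guards against
    int ix = new EnumeratedDistribution<>(d).sample();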