Author: tommaso
Date: Mon Oct 24 13:14:59 2016
New Revision: 1766403

URL: http://svn.apache.org/viewvc?rev=1766403&view=rev
Log:
minor fixes
Modified:
    labs/yay/trunk/core/src/main/java/org/apache/yay/NNRunner.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/RNN.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/StackedRNN.java
    labs/yay/trunk/pom.xml

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/NNRunner.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/NNRunner.java?rev=1766403&r1=1766402&r2=1766403&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/NNRunner.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/NNRunner.java Mon Oct 24 13:14:59 2016
@@ -48,8 +48,10 @@ public class NNRunner {
     boolean useChars = true;
     String text = "";
+    String name = "";
     if (args.length > 1 && args[1] != null) {
       Path path = Paths.get(args[1]);
+      name = path.getFileName().toString();
       try {
         byte[] bytes = Files.readAllBytes(path);
         text = new String(bytes);
@@ -81,10 +83,13 @@ public class NNRunner {
     }
     rnn.learn();
-    int seed = random.nextInt(rnn.vocabSize);
-    System.out.println(rnn.sample(seed));
+
+    for (int i = 0; i < 10; i++) {
+      int seed = random.nextInt(rnn.vocabSize);
+      System.out.println(rnn.sample(seed));
+    }
     try {
-      rnn.serialize("weights-");
+      rnn.serialize(name + "-weights-");
     } catch (IOException e) {
       throw new RuntimeException("cannot serialize weights", e);
     }

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/RNN.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/RNN.java?rev=1766403&r1=1766402&r2=1766403&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/RNN.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/RNN.java Mon Oct 24 13:14:59 2016
@@ -58,6 +58,7 @@ public class RNN {
   protected final Map<String, Integer> charToIx;
   protected final Map<Integer, String> ixToChar;
   protected final List<String> data;
+  private final static double reg = 1e-8;

   // model parameters
   private final INDArray wxh; // input to hidden
@@ -144,7 +145,7 @@ public class RNN {
       INDArray targets = getSequence(p + 1);

       // sample from the model every now and then
-      if (n % 1000 == 0 && n > 0) {
+      if (n % 100 == 0 && n > 0) {
         String txt = sample(inputs.getInt(0));
         System.out.printf("\n---\n %s \n----\n", txt);
       }
@@ -163,25 +164,25 @@ public class RNN {
         System.out.println("loss is NaN (over/underflow occured, try adjusting hyperparameters)");
         break;
       }
-      if (n % 1000 == 0) {
+      if (n % 100 == 0) {
         System.out.printf("iter %d, loss: %f\n", n, smoothLoss); // print progress
       }

       // perform parameter update with Adagrad
       mWxh.addi(dWxh.mul(dWxh));
-      wxh.subi((dWxh.mul(learningRate)).div(Transforms.sqrt(mWxh.add(1e-8))));
+      wxh.subi((dWxh.mul(learningRate)).div(Transforms.sqrt(mWxh.add(reg))));

       mWhh.addi(dWhh.mul(dWhh));
-      whh.subi(dWhh.mul(learningRate).div(Transforms.sqrt(mWhh.add(1e-8))));
+      whh.subi(dWhh.mul(learningRate).div(Transforms.sqrt(mWhh.add(reg))));

       mWhy.addi(dWhy.mul(dWhy));
-      why.subi(dWhy.mul(learningRate).div(Transforms.sqrt(mWhy.add(1e-8))));
+      why.subi(dWhy.mul(learningRate).div(Transforms.sqrt(mWhy.add(reg))));

       mbh.addi(dbh.mul(dbh));
-      bh.subi(dbh.mul(learningRate).div(Transforms.sqrt(mbh.add(1e-8))));
+      bh.subi(dbh.mul(learningRate).div(Transforms.sqrt(mbh.add(reg))));

       mby.addi(dby.mul(dby));
-      by.subi(dby.mul(learningRate).div(Transforms.sqrt(mby.add(1e-8))));
+      by.subi(dby.mul(learningRate).div(Transforms.sqrt(mby.add(reg))));

       p += seqLength; // move data pointer
       n++; // iteration counter
@@ -260,8 +261,8 @@ public class RNN {
       INDArray hsRow = t == 0 ? hs1 : hs.getRow(t - 1);
       dWhh.addi(dhraw.mmul(hsRow.transpose()));
       dhNext = whh.transpose().mmul(dhraw);
-
     }
+
     // clip exploding gradients
     Nd4j.getExecutioner().execAndReturn(new SetRange(dWxh, -5, 5));
     Nd4j.getExecutioner().execAndReturn(new SetRange(dWhh, -5, 5));
@@ -289,7 +290,7 @@ public class RNN {
     INDArray x = Nd4j.zeros(vocabSize, 1);
     x.putScalar(seedIx, 1);

-    int sampleSize = 200;
+    int sampleSize = 2 * seqLength;
     INDArray ixes = Nd4j.create(sampleSize);

     INDArray h = hPrev.dup();

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/StackedRNN.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/StackedRNN.java?rev=1766403&r1=1766402&r2=1766403&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/StackedRNN.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/StackedRNN.java Mon Oct 24 13:14:59 2016
@@ -220,9 +220,6 @@ public class StackedRNN extends RNN {
       loss += -Transforms.log(ps.getRow(t).getRow(targets.getInt(t)), true).sumNumber().doubleValue(); // softmax (cross-entropy loss)
     }

-    this.hPrev = hs.getRow(inputs.length() - 1);
-    this.hPrev2 = hs2.getRow(inputs.length() - 1);
-
     // backward pass: compute gradients going backwards
     INDArray dhNext = Nd4j.zerosLike(hs.getRow(0));
     INDArray dh2Next = Nd4j.zerosLike(hs2.getRow(0));
@@ -241,7 +238,7 @@ public class StackedRNN extends RNN {
       INDArray dh2 = wh2y.transpose().mmul(dy).add(dh2Next); // dh = np.dot(Why.T, dy) + dhnext # backprop into h2
-      INDArray dhraw2 = (Nd4j.ones(hs2t.shape()).sub(hs2t.mul(hs2t))).mul(dh2); // dhraw = (1 - hs[t] * hs[t]) * dh # backprop through tanh nonlinearity
+      INDArray dhraw2 = (Nd4j.ones(hs2t.shape()).sub(hs2t.mul(hs2t))).mul(dh2); // backprop through tanh nonlinearity
       // INDArray dhraw2 = Nd4j.getExecutioner().execAndReturn(new SetRange(hst2, 0, Double.MAX_VALUE)).mul(dh2); // backprop through relu nonlinearity
       dbh2.addi(dhraw2); // dbh += dhraw
       INDArray hst = hs.getRow(t);
@@ -254,12 +251,16 @@ public class StackedRNN extends RNN {
       // INDArray dhraw = Nd4j.getExecutioner().execAndReturn(new SetRange(hst, 0, Double.MAX_VALUE)).mul(dh); // backprop through relu nonlinearity
       dbh.addi(dhraw);
-      dWxh.addi(dhraw.mmul(xs.getRow(t)));
+      dWxh.addi(dhraw.mmul(xs.getRow(t))); // dWxh += np.dot(dhraw, xs[t].T)
       INDArray hsRow = t == 0 ? hs1 : hs.getRow(t - 1);
-      dWhh.addi(dhraw.mmul(hsRow.transpose()));
-      dhNext = whh.transpose().mmul(dhraw);
+      dWhh.addi(dhraw.mmul(hsRow.transpose())); // dWhh += np.dot(dhraw, hs[t-1].T)
+      dhNext = whh.transpose().mmul(dhraw); // dhnext = np.dot(Whh.T, dhraw)
     }
+
+    this.hPrev = hs.getRow(inputs.length() - 1);
+    this.hPrev2 = hs2.getRow(inputs.length() - 1);
+
     // clip exploding gradients
     Nd4j.getExecutioner().execAndReturn(new SetRange(dWxh, -5, 5));
     Nd4j.getExecutioner().execAndReturn(new SetRange(dWhh, -5, 5));
@@ -280,7 +281,7 @@ public class StackedRNN extends RNN {
     INDArray x = Nd4j.zeros(vocabSize, 1);
     x.putScalar(seedIx, 1);

-    int sampleSize = 200;
+    int sampleSize = seqLength * 2;
     INDArray ixes = Nd4j.create(sampleSize);

     INDArray h = hPrev.dup();

Modified: labs/yay/trunk/pom.xml
URL: http://svn.apache.org/viewvc/labs/yay/trunk/pom.xml?rev=1766403&r1=1766402&r2=1766403&view=diff
==============================================================================
--- labs/yay/trunk/pom.xml (original)
+++ labs/yay/trunk/pom.xml Mon Oct 24 13:14:59 2016
@@ -187,7 +187,7 @@
       <plugin>
         <groupId>org.apache.maven.plugins</groupId>
         <artifactId>maven-surefire-plugin</artifactId>
-        <version>2.12</version>
+        <version>2.18.1</version>
       </plugin>
     </plugins>
     <resources>
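For reference, the Adagrad step that the new reg constant feeds is the usual per-parameter update seen in RNN.learn() above. Below is a minimal standalone sketch of that step using the same ND4J calls that appear in the diff; the class name AdagradSketch, the names param, grad and memory, and the 3x3 shapes are illustrative only and do not exist in the project.

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;

public class AdagradSketch {
  public static void main(String[] args) {
    double learningRate = 1e-1;
    double reg = 1e-8; // the epsilon term this commit factors into a single constant

    INDArray param = Nd4j.rand(3, 3);   // a weight matrix, e.g. wxh
    INDArray grad = Nd4j.rand(3, 3);    // its gradient, e.g. dWxh
    INDArray memory = Nd4j.zeros(3, 3); // accumulated squared gradients, e.g. mWxh

    // same shape as the updates in RNN.learn():
    memory.addi(grad.mul(grad));                                              // m += g^2
    param.subi(grad.mul(learningRate).div(Transforms.sqrt(memory.add(reg)))); // p -= lr * g / sqrt(m + reg)

    System.out.println(param);
  }
}

Factoring the literal 1e-8 out into reg keeps the five per-parameter updates consistent and leaves a single place to adjust the epsilon if the denominator ever needs more damping.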