Author: tommaso Date: Wed Mar 2 11:13:32 2016 New Revision: 1733257 URL: http://svn.apache.org/viewvc?rev=1733257&view=rev Log: biases initialized to 0.01, per output word softmax, cached expanded representation in hot encoded samples, 0.5 decay on learning rate
Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/HotEncodedSample.java labs/yay/trunk/core/src/main/java/org/apache/yay/SkipGramNetwork.java labs/yay/trunk/core/src/test/java/org/apache/yay/SkipGramNetworkTest.java Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/HotEncodedSample.java URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/HotEncodedSample.java?rev=1733257&r1=1733256&r2=1733257&view=diff ============================================================================== --- labs/yay/trunk/core/src/main/java/org/apache/yay/HotEncodedSample.java (original) +++ labs/yay/trunk/core/src/main/java/org/apache/yay/HotEncodedSample.java Wed Mar 2 11:13:32 2016 @@ -25,6 +25,9 @@ import java.util.Arrays; */ public class HotEncodedSample extends Sample { + private double[] expandedInputs = null; + private double[] expandedOutputs = null; + private final int vocabularySize; public HotEncodedSample(double[] inputs, double[] outputs, int vocabularySize) { @@ -34,26 +37,32 @@ public class HotEncodedSample extends Sa @Override public double[] getInputs() { - double[] inputs = new double[this.inputs.length * vocabularySize]; - int i = 0; - for (double d : this.inputs) { - double[] currentInput = hotEncode((int) d); - System.arraycopy(currentInput, 0, inputs, i, currentInput.length); - i += vocabularySize; + if (expandedInputs == null) { + double[] inputs = new double[this.inputs.length * vocabularySize]; + int i = 0; + for (double d : this.inputs) { + double[] currentInput = hotEncode((int) d); + System.arraycopy(currentInput, 0, inputs, i, currentInput.length); + i += vocabularySize; + } + expandedInputs = inputs; } - return inputs; + return expandedInputs; } @Override public double[] getOutputs() { - double[] outputs = new double[this.outputs.length * vocabularySize]; - int i = 0; - for (double d : this.outputs) { - double[] currentOutput = hotEncode((int) d); - System.arraycopy(currentOutput, 0, outputs, i, 
currentOutput.length); - i += vocabularySize; + if (expandedOutputs == null) { + double[] outputs = new double[this.outputs.length * vocabularySize]; + int i = 0; + for (double d : this.outputs) { + double[] currentOutput = hotEncode((int) d); + System.arraycopy(currentOutput, 0, outputs, i, currentOutput.length); + i += vocabularySize; + } + expandedOutputs = outputs; } - return outputs; + return expandedOutputs; } private double[] hotEncode(int index) { Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/SkipGramNetwork.java URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/SkipGramNetwork.java?rev=1733257&r1=1733256&r2=1733257&view=diff ============================================================================== --- labs/yay/trunk/core/src/main/java/org/apache/yay/SkipGramNetwork.java (original) +++ labs/yay/trunk/core/src/main/java/org/apache/yay/SkipGramNetwork.java Wed Mar 2 11:13:32 2016 @@ -75,26 +75,18 @@ public class SkipGramNetwork { private SkipGramNetwork(Configuration configuration) { this.configuration = configuration; - this.weights = createRandomWeights(); - this.biases = createRandomBiases(); + this.weights = initWeights(); + this.biases = initBiases(); } - private RealMatrix[] createRandomBiases() { + private RealMatrix[] initBiases() { RealMatrix[] initialBiases = new RealMatrix[weights.length]; for (int i = 0; i < initialBiases.length; i++) { - RealMatrix matrix = MatrixUtils.createRealMatrix(1, weights[i].getRowDimension()); - - UniformRealDistribution uniformRealDistribution = new UniformRealDistribution(); - double[] vs = uniformRealDistribution.sample(matrix.getRowDimension() * matrix.getColumnDimension()); - int r = 0; - int c = 0; - for (double v : vs) { - matrix.setEntry(r % matrix.getRowDimension(), c % matrix.getColumnDimension(), v); - r++; - c++; - } + double[] data = new double[weights[i].getRowDimension()]; + Arrays.fill(data, 0.01d); + RealMatrix matrix = 
MatrixUtils.createRowRealMatrix(data); initialBiases[i] = matrix; } @@ -113,13 +105,21 @@ public class SkipGramNetwork { RealMatrix hidden = rectifierFunction.applyMatrix(MatrixUtils.createRowRealMatrix(input).multiply(weights[0].transpose()). add(biases[0])); - RealMatrix pscores = softmaxActivationFunction.applyMatrix(hidden.multiply(weights[1].transpose()).add(biases[1])); + RealMatrix scores = hidden.multiply(weights[1].transpose()).add(biases[1]); - RealVector d = pscores.getRowVector(0); + RealMatrix probs = scores.copy(); + int len = scores.getColumnDimension() - 1; + for (int d = 0; d < configuration.window - 1; d++) { + int startColumn = d * len / (configuration.window - 1); + RealMatrix subMatrix = scores.getSubMatrix(0, scores.getRowDimension() - 1, startColumn, startColumn + input.length); + probs.setSubMatrix(softmaxActivationFunction.applyMatrix(subMatrix).getData(), 0, startColumn); + } + + RealVector d = probs.getRowVector(0); return d.toArray(); } - private RealMatrix[] createRandomWeights() { + private RealMatrix[] initWeights() { int[] conf = new int[]{configuration.inputs, configuration.vectorSize, configuration.outputs}; int[] layers = new int[conf.length]; System.arraycopy(conf, 0, layers, 0, layers.length); @@ -146,9 +146,10 @@ public class SkipGramNetwork { } - static void evaluate(SkipGramNetwork network, int window) throws Exception { + static double evaluate(SkipGramNetwork network) throws Exception { double cc = 0; double wc = 0; + int window = network.configuration.window; for (Sample sample : network.samples) { Collection<Integer> exps = new ArrayList<>(window - 1); Collection<Integer> acts = new ArrayList<>(window - 1); @@ -184,9 +185,7 @@ public class SkipGramNetwork { } } - if (cc > 0) { - System.out.println("accuracy: " + (cc / (wc + cc))); - } + return (cc / (wc + cc)); } private static int getMaxIndex(double[] array, int start, int end) { @@ -240,12 +239,8 @@ public class SkipGramNetwork { System.out.println("cost is " + cost + " 
after " + iterations + " iterations in " + (time / 60) + " minutes (" + ((double) iterations / time) + " ips)"); } } - if (iterations % 1000 == 0) { - evaluate(this, this.configuration.window); - System.out.println("cost: " + cost); - } -// configuration.alpha = configuration.alpha * 0.999; + configuration.alpha = configuration.alpha * 0.5; RealMatrix w0t = weights[0].transpose(); final RealMatrix w1t = weights[1].transpose(); @@ -285,7 +280,13 @@ public class SkipGramNetwork { } }); - RealMatrix probs = softmaxActivationFunction.applyMatrix(scores); + RealMatrix probs = scores.copy(); + int len = scores.getColumnDimension() - 1; + for (int d = 0; d < configuration.window - 1; d++) { + int startColumn = d * len / (configuration.window - 1); + RealMatrix subMatrix = scores.getSubMatrix(0, scores.getRowDimension() - 1, startColumn, startColumn + x.getColumnDimension()); + probs.setSubMatrix(softmaxActivationFunction.applyMatrix(subMatrix).getData(), 0, startColumn); + } RealMatrix correctLogProbs = MatrixUtils.createRealMatrix(x.getRowDimension(), 1); correctLogProbs.walkInOptimizedOrder(new RealMatrixChangingVisitor() { @@ -562,7 +563,7 @@ public class SkipGramNetwork { @Override public double visit(int row, int column, double value) { - return configuration.mu * value - configuration.alpha + dWt2.getEntry(row, column); + return configuration.mu * value + configuration.alpha - dWt2.getEntry(row, column); } @Override @@ -813,6 +814,11 @@ public class SkipGramNetwork { return this; } + public Builder withMaxIterations(int iterations) { + this.configuration.maxIterations = iterations; + return this; + } + public SkipGramNetwork build() throws Exception { System.out.println("reading fragments"); Queue<List<byte[]>> fragments = getFragments(this.configuration.path, this.configuration.window); @@ -825,7 +831,9 @@ public class SkipGramNetwork { System.out.println("creating training set"); Collection<HotEncodedSample> trainingSet = createTrainingSet(vocabulary, fragments, 
this.configuration.window); fragments.clear(); - this.configuration.maxIterations = trainingSet.size() * 100000; + if (this.configuration.maxIterations == 0) { + this.configuration.maxIterations = trainingSet.size() * 100000; + } HotEncodedSample next = trainingSet.iterator().next(); Modified: labs/yay/trunk/core/src/test/java/org/apache/yay/SkipGramNetworkTest.java URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/SkipGramNetworkTest.java?rev=1733257&r1=1733256&r2=1733257&view=diff ============================================================================== --- labs/yay/trunk/core/src/test/java/org/apache/yay/SkipGramNetworkTest.java (original) +++ labs/yay/trunk/core/src/test/java/org/apache/yay/SkipGramNetworkTest.java Wed Mar 2 11:13:32 2016 @@ -43,47 +43,52 @@ public class SkipGramNetworkTest { @Test public void testWordVectorsLearningOnAbstracts() throws Exception { Path path = Paths.get(getClass().getResource("/word2vec/abstracts.txt").getFile()); - int window = 3; SkipGramNetwork network = SkipGramNetwork.newModel(). - withWindow(window). + withWindow(3). fromTextAt(path). - withDimension(2). - withAlpha(0.003). - withLambda(0.00003). + withDimension(10). + withAlpha(1). + withLambda(0.003). + withMaxIterations(500). build(); RealMatrix wv = network.getWeights()[0]; List<String> vocabulary = network.getVocabulary(); serialize(vocabulary, wv); - SkipGramNetwork.evaluate(network, window); + System.err.println("accuracy: " + SkipGramNetwork.evaluate(network)); + measure(vocabulary, wv); } @Test public void testWordVectorsLearningOnSentences() throws Exception { Path path = Paths.get(getClass().getResource("/word2vec/sentences.txt").getFile()); - int window = 3; SkipGramNetwork network = SkipGramNetwork.newModel(). - withWindow(window). + withWindow(3). fromTextAt(path). - withDimension(10).build(); + withDimension(10). + withAlpha(1). + withLambda(0.03). + withMaxIterations(500). 
+ build(); RealMatrix wv = network.getWeights()[0]; List<String> vocabulary = network.getVocabulary(); serialize(vocabulary, wv); - SkipGramNetwork.evaluate(network, window); + System.err.println("accuracy: " + SkipGramNetwork.evaluate(network)); + measure(vocabulary, wv); } @Test public void testWordVectorsLearningOnTestData() throws Exception { Path path = Paths.get(getClass().getResource("/word2vec/test.txt").getFile()); - int window = 3; SkipGramNetwork network = SkipGramNetwork.newModel(). - withWindow(window). + withWindow(3). fromTextAt(path). withDimension(2). - withAlpha(0.00002). + withAlpha(1). withLambda(0.03). - withThreshold(0.00000000003). + withThreshold(0.000003). + withMaxIterations(1000). build(); - SkipGramNetwork.evaluate(network, window); + System.err.println("accuracy: " + SkipGramNetwork.evaluate(network)); RealMatrix wv = network.getWeights()[0]; List<String> vocabulary = network.getVocabulary(); serialize(vocabulary, wv); --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@labs.apache.org For additional commands, e-mail: commits-h...@labs.apache.org