Author: tommaso
Date: Thu Feb 25 11:17:56 2016
New Revision: 1732287

URL: http://svn.apache.org/viewvc?rev=1732287&view=rev
Log:
refactored MLN and improved SGN, commons-math 3.6
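The functional heart of this change is the new forward pass in SkipGramNetwork: a per-layer bias row is added to each affine transform, so prediction becomes softmax(relu(x * W0^T + b0) * W1^T + b1), as predictOutput in the diff below shows. For readers who want to try that pass outside the network, here is a minimal self-contained sketch against commons-math 3.6; the class name and the relu/softmax helpers are illustrative stand-ins for the network's rectifierFunction and softmaxActivationFunction, whose bodies this diff does not show (the max-shift in softmax is a common stability trick assumed here, not necessarily what the commit's activation does).

import org.apache.commons.math3.linear.MatrixUtils;
import org.apache.commons.math3.linear.RealMatrix;

public class ForwardPassSketch {

    // Hypothetical stand-in for rectifierFunction.applyMatrix: element-wise max(0, v).
    static RealMatrix relu(RealMatrix m) {
        RealMatrix out = m.copy();
        for (int i = 0; i < out.getRowDimension(); i++) {
            for (int j = 0; j < out.getColumnDimension(); j++) {
                out.setEntry(i, j, Math.max(0d, out.getEntry(i, j)));
            }
        }
        return out;
    }

    // Hypothetical stand-in for softmaxActivationFunction.applyMatrix: row-wise softmax,
    // shifting by the row maximum before exponentiating to avoid overflow (an assumption).
    static RealMatrix softmax(RealMatrix m) {
        RealMatrix out = m.copy();
        for (int i = 0; i < out.getRowDimension(); i++) {
            double max = Double.NEGATIVE_INFINITY;
            for (int j = 0; j < out.getColumnDimension(); j++) {
                max = Math.max(max, out.getEntry(i, j));
            }
            double sum = 0d;
            for (int j = 0; j < out.getColumnDimension(); j++) {
                double e = Math.exp(out.getEntry(i, j) - max);
                out.setEntry(i, j, e);
                sum += e;
            }
            for (int j = 0; j < out.getColumnDimension(); j++) {
                out.setEntry(i, j, out.getEntry(i, j) / sum);
            }
        }
        return out;
    }

    // Mirrors predictOutput from the diff: weights[l] is (neurons in layer l+1) x (neurons in layer l)
    // and biases[l] is a single row, so a row input x yields a row of vocabulary-sized probabilities.
    static double[] predict(double[] input, RealMatrix[] weights, RealMatrix[] biases) {
        RealMatrix x = MatrixUtils.createRowRealMatrix(input);
        RealMatrix hidden = relu(x.multiply(weights[0].transpose()).add(biases[0]));
        RealMatrix scores = hidden.multiply(weights[1].transpose()).add(biases[1]);
        return softmax(scores).getRowVector(0).toArray();
    }
}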
Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/HotEncodedSample.java labs/yay/trunk/core/src/main/java/org/apache/yay/SkipGramNetwork.java labs/yay/trunk/core/src/test/java/org/apache/yay/MultiLayerNetworkTest.java labs/yay/trunk/core/src/test/java/org/apache/yay/SkipGramNetworkTest.java labs/yay/trunk/core/src/test/resources/word2vec/test.txt labs/yay/trunk/pom.xml Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/HotEncodedSample.java URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/HotEncodedSample.java?rev=1732287&r1=1732286&r2=1732287&view=diff ============================================================================== --- labs/yay/trunk/core/src/main/java/org/apache/yay/HotEncodedSample.java (original) +++ labs/yay/trunk/core/src/main/java/org/apache/yay/HotEncodedSample.java Thu Feb 25 11:17:56 2016 @@ -34,9 +34,8 @@ public class HotEncodedSample extends Sa @Override public double[] getInputs() { - double[] inputs = new double[1 + this.inputs.length * vocabularySize]; - inputs[0] = 1d; - int i = 1; + double[] inputs = new double[this.inputs.length * vocabularySize]; + int i = 0; for (double d : this.inputs) { double[] currentInput = hotEncode((int) d); System.arraycopy(currentInput, 0, inputs, i, currentInput.length); Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/SkipGramNetwork.java URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/SkipGramNetwork.java?rev=1732287&r1=1732286&r2=1732287&view=diff ============================================================================== --- labs/yay/trunk/core/src/main/java/org/apache/yay/SkipGramNetwork.java (original) +++ labs/yay/trunk/core/src/main/java/org/apache/yay/SkipGramNetwork.java Thu Feb 25 11:17:56 2016 @@ -24,6 +24,7 @@ import org.apache.commons.math3.linear.M import org.apache.commons.math3.linear.RealMatrix; import org.apache.commons.math3.linear.RealMatrixChangingVisitor; import org.apache.commons.math3.linear.RealMatrixPreservingVisitor; +import org.apache.commons.math3.linear.RealVector; import java.io.IOException; import java.nio.ByteBuffer; @@ -68,11 +69,44 @@ public class SkipGramNetwork { * the second row of weights[0] holds the weights connecting each neuron of the first layer to the second neuron of the second layer, etc.
*/ private RealMatrix[] weights; + private RealMatrix[] biases; + private Sample[] samples; private SkipGramNetwork(Configuration configuration) { this.configuration = configuration; this.weights = createRandomWeights(); + this.biases = createRandomBiases(); + } + + private RealMatrix[] createRandomBiases() { + Random r = new Random(); + + RealMatrix[] initialWeights = new RealMatrix[weights.length]; + + for (int i = 0; i < initialWeights.length; i++) { + + RealMatrix matrix = MatrixUtils.createRealMatrix(1, weights[i].getRowDimension()); + matrix.walkInOptimizedOrder(new RealMatrixChangingVisitor() { + @Override + public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) { + + } + + @Override + public double visit(int row, int column, double value) { + return 1;//r.nextInt(100000) / 10000001d; + } + + @Override + public double end() { + return 0; + } + }); + + initialWeights[i] = matrix; + } + return initialWeights; } public RealMatrix[] getWeights() { @@ -83,13 +117,21 @@ public class SkipGramNetwork { return configuration.vocabulary; } + public double[] predictOutput(double[] input) throws Exception { + + RealMatrix hidden = rectifierFunction.applyMatrix(MatrixUtils.createRowRealMatrix(input).multiply(weights[0].transpose()). + add(biases[0])); + RealMatrix pscores = softmaxActivationFunction.applyMatrix(hidden.multiply(weights[1].transpose()).add(biases[1])); + + RealVector d = pscores.getRowVector(0); + return d.toArray(); + } + private RealMatrix[] createRandomWeights() { Random r = new Random(); int[] conf = new int[]{configuration.inputs, configuration.vectorSize, configuration.outputs}; int[] layers = new int[conf.length]; - for (int i = 0; i < layers.length; i++) { - layers[i] = conf[i] + (i < layers.length - 1 ? 
1 : 0); - } + System.arraycopy(conf, 0, layers, 0, layers.length); int weightsCount = layers.length - 1; RealMatrix[] initialWeights = new RealMatrix[weightsCount]; @@ -97,7 +139,6 @@ public class SkipGramNetwork { for (int i = 0; i < weightsCount; i++) { RealMatrix matrix = MatrixUtils.createRealMatrix(layers[i + 1], layers[i]); - final int finalI = i; matrix.walkInOptimizedOrder(new RealMatrixChangingVisitor() { @Override public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) { @@ -106,12 +147,7 @@ public class SkipGramNetwork { @Override public double visit(int row, int column, double value) { - if (finalI != weightsCount - 1 && row == 0) { - return 0d; - } else if (column == 0) { - return 1d; - } - return r.nextInt(100) / 101d; + return r.nextInt(10) / 1000000001d; } @Override @@ -126,6 +162,50 @@ public class SkipGramNetwork { } + private void evaluate() throws Exception { + double cc = 0; + double wc = 0; + for (Sample sample : samples) { + int window = configuration.window; + Collection<Integer> exps = new ArrayList<>(window - 1); + Collection<Integer> acts = new ArrayList<>(window - 1); + double[] inputs = sample.getInputs(); + double[] actualOutputs = predictOutput(inputs); + double[] expectedOutputs = sample.getOutputs(); + int j = 0; + for (int i = 0; i < window - 1; i++) { + int actualMax = getMaxIndex(actualOutputs, j, j + inputs.length - 1); + int expectedMax = getMaxIndex(expectedOutputs, j, j + inputs.length - 1); + exps.add(expectedMax); + acts.add(actualMax); + j += i + inputs.length - 2; + } + boolean c = true; + for (Integer a : acts) { + c &= exps.contains(a); + } + if (c) { + cc++; + } else { + wc++; + } + + } + System.out.println("accuracy: " + (cc / (wc + cc))); + } + + private int getMaxIndex(double[] array, int start, int end) { + double largest = array[start]; + int index = 0; + for (int i = start + 1; i < end; i++) { + if (array[i] >= largest) { + largest = array[i]; + index = i; + } + } + return index; + } + // --- batch gradient descent --- /** @@ -140,29 +220,65 @@ public class SkipGramNetwork { int iterations = 0; double cost = Double.MAX_VALUE; + + RealMatrix x = MatrixUtils.createRealMatrix(samples.length, samples[0].getInputs().length); + RealMatrix y = MatrixUtils.createRealMatrix(samples.length, samples[0].getOutputs().length); + int i = 0; + for (Sample sample : samples) { + x.setRow(i, ArrayUtils.addAll(sample.getInputs())); + y.setRow(i, ArrayUtils.addAll(sample.getOutputs())); + i++; + } + long start = System.currentTimeMillis(); while (true) { - if (iterations % (1 + (configuration.maxIterations / 100)) == 0) { - long time = (System.currentTimeMillis() - start) / 1000; + long time = (System.currentTimeMillis() - start) / 1000; + if (iterations % (1 + (configuration.maxIterations / 100)) == 0 || time % 300 < 2) { if (time > 60) { System.out.println("cost is " + cost + " after " + iterations + " iterations in " + (time / 60) + " minutes (" + ((double) iterations / time) + " ips)"); } } - double newCost = 0; - RealMatrix x = MatrixUtils.createRealMatrix(samples.length, samples[0].getInputs().length); - RealMatrix y = MatrixUtils.createRealMatrix(samples.length, samples[0].getOutputs().length); - int i = 0; - for (Sample sample : samples) { - x.setRow(i, ArrayUtils.addAll(sample.getInputs())); - y.setRow(i, ArrayUtils.addAll(sample.getOutputs())); - i++; + if (iterations % 1000 == 0) { + evaluate(); } RealMatrix w0t = weights[0].transpose(); final RealMatrix w1t = weights[1].transpose(); RealMatrix hidden = 
rectifierFunction.applyMatrix(x.multiply(w0t)); + hidden.walkInOptimizedOrder(new RealMatrixChangingVisitor() { + @Override + public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) { + + } + + @Override + public double visit(int row, int column, double value) { + return value + biases[0].getEntry(0, column); + } + + @Override + public double end() { + return 0; + } + }); RealMatrix scores = hidden.multiply(w1t); + scores.walkInOptimizedOrder(new RealMatrixChangingVisitor() { + @Override + public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) { + + } + + @Override + public double visit(int row, int column, double value) { + return value + biases[1].getEntry(0, column); + } + + @Override + public double end() { + return 0; + } + }); RealMatrix probs = softmaxActivationFunction.applyMatrix(scores); @@ -221,11 +337,34 @@ public class SkipGramNetwork { return d; } }); - newCost = dataLoss + 0.5 * 0.03 * reg; + reg += weights[1].walkInOptimizedOrder(new RealMatrixPreservingVisitor() { + private double d = 0d; + + @Override + public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) { + + } + + @Override + public void visit(int row, int column, double value) { + d += Math.pow(value, 2); + } + + @Override + public double end() { + return d; + } + }); + + double regLoss = 0.5 * configuration.regularizationLambda * reg; + double newCost = dataLoss + regLoss; + if (iterations == 0) { + System.out.println("started with cost = " + dataLoss + " + " + regLoss); + } if (Double.POSITIVE_INFINITY == newCost || newCost > cost) { throw new Exception("failed to converge at iteration " + iterations + " with alpha " + configuration.alpha + " : cost going from " + cost + " to " + newCost); - } else if (iterations > 1 && (newCost < configuration.threshold || iterations > configuration.maxIterations)) { + } else if (iterations > 1 && (newCost < configuration.threshold || iterations > configuration.maxIterations || cost - newCost < configuration.threshold)) { cost = newCost; System.out.println("successfully converged after " + (iterations - 1) + " iterations (alpha:" + configuration.alpha + ",threshold:" + configuration.threshold + ") with cost " + newCost); break; @@ -238,7 +377,7 @@ public class SkipGramNetwork { // calculate the derivatives to update the parameters - RealMatrix dscores = probs; + RealMatrix dscores = probs.copy(); dscores.walkInOptimizedOrder(new RealMatrixChangingVisitor() { @Override public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) { @@ -267,11 +406,25 @@ public class SkipGramNetwork { @Override public double visit(int row, int column, double value) { - if (column != 0) { - return value + 0.3 * w1t.getEntry(row, column); - } else { - return value; - } + return value + configuration.regularizationLambda * w1t.getEntry(row, column); + } + + @Override + public double end() { + return 0; + } + }); + + RealMatrix db2 = MatrixUtils.createRealMatrix(biases[1].getRowDimension(), biases[1].getColumnDimension()); + dscores.walkInOptimizedOrder(new RealMatrixPreservingVisitor() { + @Override + public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) { + + } + + @Override + public void visit(int row, int column, double value) { + db2.setEntry(0, column, db2.getEntry(0, column) + value); } @Override @@ -281,6 +434,40 @@ public class SkipGramNetwork { }); RealMatrix dhidden = 
dscores.multiply(weights[1]); + dhidden.walkInOptimizedOrder(new RealMatrixChangingVisitor() { + @Override + public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) { + + } + + @Override + public double visit(int row, int column, double value) { + return value < 0 ? 0 : value; + } + + @Override + public double end() { + return 0; + } + }); + + RealMatrix db = MatrixUtils.createRealMatrix(biases[0].getRowDimension(), biases[0].getColumnDimension()); + dhidden.walkInOptimizedOrder(new RealMatrixPreservingVisitor() { + @Override + public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) { + + } + + @Override + public void visit(int row, int column, double value) { + db.setEntry(0, column, db.getEntry(0, column) + value); + } + + @Override + public double end() { + return 0; + } + }); RealMatrix dW = x.transpose().multiply(dhidden); dW.walkInOptimizedOrder(new RealMatrixChangingVisitor() { @@ -291,11 +478,42 @@ public class SkipGramNetwork { @Override public double visit(int row, int column, double value) { - if (column != 0) { - return value + 0.03 * w0t.getEntry(row, column); - } else { - return value; - } + return value + configuration.regularizationLambda * w0t.getEntry(row, column); + } + + @Override + public double end() { + return 0; + } + }); + + // update bias + biases[0].walkInOptimizedOrder(new RealMatrixChangingVisitor() { + @Override + public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) { + + } + + @Override + public double visit(int row, int column, double value) { + return value - configuration.alpha * db.getEntry(row, column); + } + + @Override + public double end() { + return 0; + } + }); + + biases[1].walkInOptimizedOrder(new RealMatrixChangingVisitor() { + @Override + public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) { + + } + + @Override + public double visit(int row, int column, double value) { + return value - configuration.alpha * db2.getEntry(row, column); } @Override @@ -307,10 +525,7 @@ public class SkipGramNetwork { RealMatrix[] derivatives = new RealMatrix[]{dW.transpose(), dW2.transpose()}; // update the weights - RealMatrix[] updatedParameters = new RealMatrix[weights.length]; - for (int l = 0; l < weights.length; l++) { - RealMatrix realMatrix = weights[l].copy(); final int finalL = l; RealMatrixChangingVisitor visitor = new RealMatrixChangingVisitor() { @@ -321,11 +536,7 @@ public class SkipGramNetwork { @Override public double visit(int row, int column, double value) { - if (!(row == 0 && value == 0d) && !(column == 0 && value == 1d)) { return value - configuration.alpha * derivatives[finalL].getEntry(row, column); - } else { - return value; - } } @Override @@ -333,10 +544,8 @@ public class SkipGramNetwork { return 0; } }; - realMatrix.walkInOptimizedOrder(visitor); - updatedParameters[l] = realMatrix; + weights[l].walkInOptimizedOrder(visitor); } - weights = updatedParameters; iterations++; } @@ -359,6 +568,10 @@ public class SkipGramNetwork { return new Builder(); } + public Sample[] getSamples() { + return samples; + } + // --- skip gram neural network configuration --- private static class Configuration { @@ -371,8 +584,9 @@ public class SkipGramNetwork { // user controlled parameters protected Path path; protected int maxIterations; - protected double alpha = 0.003d; - protected double threshold = 0.004d; + protected double alpha = 0.0001d; + protected double 
regularizationLambda = 0.000000000003; + protected double threshold = 0.0000000000004d; protected int vectorSize; protected int window; } @@ -405,43 +619,53 @@ public class SkipGramNetwork { Queue<List<byte[]>> fragments = getFragments(this.configuration.path, this.configuration.window); assert !fragments.isEmpty() : "could not read fragments"; System.out.println("generating vocabulary"); - List<String> vocabulary = getVocabulary(this.configuration.path); +// List<String> vocabulary = getVocabulary(this.configuration.path); + List<String> vocabulary = getVocabulary(fragments); assert !vocabulary.isEmpty() : "could not read vocabulary"; this.configuration.vocabulary = vocabulary; System.out.println("creating training set"); Collection<HotEncodedSample> trainingSet = createTrainingSet(vocabulary, fragments, this.configuration.window); fragments.clear(); - this.configuration.maxIterations = trainingSet.size() * 10; + this.configuration.maxIterations = trainingSet.size() * 100000; HotEncodedSample next = trainingSet.iterator().next(); - this.configuration.inputs = next.getInputs().length - 1; + this.configuration.inputs = next.getInputs().length; this.configuration.outputs = next.getOutputs().length; SkipGramNetwork network = new SkipGramNetwork(configuration); - network.learnWeights(trainingSet.toArray(new Sample[trainingSet.size()])); + network.samples = trainingSet.toArray(new Sample[trainingSet.size()]); + network.learnWeights(network.samples); return network; } + private List<String> getVocabulary(Queue<List<byte[]>> fragments) { + List<String> vocabulary = new LinkedList<>(); + for (List<byte[]> fragment : fragments) { + for (byte[] word : fragment) { + String s = new String(word); + if (!vocabulary.contains(s)) { + vocabulary.add(s); + } + } + } + return vocabulary; + } + private Collection<HotEncodedSample> createTrainingSet(final List<String> vocabulary, Queue<List<byte[]>> fragments, int window) throws IOException { long start = System.currentTimeMillis(); Collection<HotEncodedSample> samples = new LinkedList<>(); List<byte[]> fragment; while ((fragment = fragments.poll()) != null) { - byte[] inputWord = null; List<byte[]> outputWords = new ArrayList<>(fragment.size() - 1); - for (int i = 0; i < fragment.size(); i++) { - for (int j = 0; j < fragment.size(); j++) { - byte[] token = fragment.get(i); - if (i == j) { - inputWord = token; - } else { - outputWords.add(token); - } + int inputIdx = fragment.size() / 2; + byte[] inputWord = fragment.get(inputIdx); + for (int k = 0; k < fragment.size(); k++) { + if (k != inputIdx) { + outputWords.add(fragment.get(k)); } } - final byte[] finalInputWord = inputWord; double[] doubles = new double[window - 1]; for (int i = 0; i < doubles.length; i++) { @@ -449,7 +673,7 @@ public class SkipGramNetwork { } double[] inputs = new double[1]; - inputs[0] = (double) vocabulary.indexOf(new String(finalInputWord)); + inputs[0] = (double) vocabulary.indexOf(new String(inputWord)); samples.add(new HotEncodedSample(inputs, doubles, vocabulary.size())); @@ -480,9 +704,11 @@ public class SkipGramNetwork { if (splitSize > w) { for (int j = 0; j < splitSize - w; j++) { List<byte[]> fragment = new ArrayList<>(w); - fragment.add(previous.append(split.get(j)).toString().getBytes()); + String str = split.get(j); + fragment.add(previous.append(str).toString().getBytes()); for (int i = 1; i < w; i++) { - fragment.add(split.get(i + j).getBytes()); + String s = split.get(i + j); + fragment.add(s.getBytes()); } // TODO : this has to be used to re-use the tokens that have 
not been consumed in next iteration fragments.add(fragment); @@ -546,7 +772,7 @@ public class SkipGramNetwork { private String cleanString(CharBuffer charBuffer) { String s = charBuffer.toString(); - return s.toLowerCase().replaceAll("\\.", " ").replaceAll("\\;", " ").replaceAll("\\,", " ").replaceAll("\\:", " ").replaceAll("\\-\\s", "").replaceAll("\\\"", ""); + return s.toLowerCase().replaceAll("\\.", " ");//.replaceAll("\\;", " ").replaceAll("\\,", " ").replaceAll("\\:", " ").replaceAll("\\-\\s", "").replaceAll("\\\"", ""); } } } \ No newline at end of file Modified: labs/yay/trunk/core/src/test/java/org/apache/yay/MultiLayerNetworkTest.java URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/MultiLayerNetworkTest.java?rev=1732287&r1=1732286&r2=1732287&view=diff ============================================================================== --- labs/yay/trunk/core/src/test/java/org/apache/yay/MultiLayerNetworkTest.java (original) +++ labs/yay/trunk/core/src/test/java/org/apache/yay/MultiLayerNetworkTest.java Thu Feb 25 11:17:56 2016 @@ -149,7 +149,6 @@ public class MultiLayerNetworkTest { MultiLayerNetwork nor = new MultiLayerNetwork(configuration, norRealMatrixSet); - assertEquals(0L, Math.round(nor.predictOutput(new double[]{1d, 0d})[0])); assertEquals(0L, Math.round(nor.predictOutput(new double[]{0d, 1d})[0])); assertEquals(1L, Math.round(nor.predictOutput(new double[]{0d, 0d})[0])); Modified: labs/yay/trunk/core/src/test/java/org/apache/yay/SkipGramNetworkTest.java URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/SkipGramNetworkTest.java?rev=1732287&r1=1732286&r2=1732287&view=diff ============================================================================== --- labs/yay/trunk/core/src/test/java/org/apache/yay/SkipGramNetworkTest.java (original) +++ labs/yay/trunk/core/src/test/java/org/apache/yay/SkipGramNetworkTest.java Thu Feb 25 11:17:56 2016 @@ -18,9 +18,14 @@ */ package org.apache.yay; +import org.apache.commons.math3.linear.MatrixUtils; import org.apache.commons.math3.linear.RealMatrix; +import org.apache.commons.math3.ml.distance.CanberraDistance; +import org.apache.commons.math3.ml.distance.ChebyshevDistance; import org.apache.commons.math3.ml.distance.DistanceMeasure; import org.apache.commons.math3.ml.distance.EuclideanDistance; +import org.apache.commons.math3.ml.distance.ManhattanDistance; +import org.apache.commons.math3.util.FastMath; import org.junit.Test; import java.io.BufferedWriter; @@ -29,8 +34,10 @@ import java.io.FileWriter; import java.io.IOException; import java.nio.file.Path; import java.nio.file.Paths; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; +import java.util.Date; import java.util.LinkedList; import java.util.List; @@ -42,64 +49,76 @@ public class SkipGramNetworkTest { @Test public void testWordVectorsLearningOnAbstracts() throws Exception { Path path = Paths.get(getClass().getResource("/word2vec/abstracts.txt").getFile()); - SkipGramNetwork network = SkipGramNetwork.newModel().withWindow(4).fromTextAt(path).withDimension(10).build(); + int window = 3; + SkipGramNetwork network = SkipGramNetwork.newModel().withWindow(window).fromTextAt(path).withDimension(10).build(); RealMatrix wv = network.getWeights()[0]; List<String> vocabulary = network.getVocabulary(); serialize(vocabulary, wv); - measure(vocabulary, wv); + evaluate(network, window); } @Test public void testWordVectorsLearningOnSentences() throws Exception { Path path = 
Paths.get(getClass().getResource("/word2vec/sentences.txt").getFile()); - SkipGramNetwork network = SkipGramNetwork.newModel().withWindow(4).fromTextAt(path).withDimension(10).build(); + int window = 3; + SkipGramNetwork network = SkipGramNetwork.newModel().withWindow(window).fromTextAt(path).withDimension(10).build(); RealMatrix wv = network.getWeights()[0]; List<String> vocabulary = network.getVocabulary(); serialize(vocabulary, wv); - measure(vocabulary, wv); + evaluate(network, window); } @Test public void testWordVectorsLearningOnTestData() throws Exception { Path path = Paths.get(getClass().getResource("/word2vec/test.txt").getFile()); - SkipGramNetwork network = SkipGramNetwork.newModel().withWindow(4).fromTextAt(path).withDimension(10).build(); + int window = 3; + SkipGramNetwork network = SkipGramNetwork.newModel().withWindow(window).fromTextAt(path).withDimension(10).build(); + evaluate(network, window); + network.learnWeights(network.getSamples()); + evaluate(network, window); RealMatrix wv = network.getWeights()[0]; List<String> vocabulary = network.getVocabulary(); serialize(vocabulary, wv); - measure(vocabulary, wv); } private void measure(List<String> vocabulary, RealMatrix wordVectors) { System.out.println("measuring similarities"); Collection<DistanceMeasure> measures = new LinkedList<>(); measures.add(new EuclideanDistance()); -// measures.add(new DistanceMeasure() { -// @Override -// public double compute(double[] a, double[] b) { -// double dp = 0.0; -// double na = 0.0; -// double nb = 0.0; -// for (int i = 0; i < a.length; i++) { -// dp += a[i] * b[i]; -// na += Math.pow(a[i], 2); -// nb += Math.pow(b[i], 2); -// } -// double cosineSimilarity = dp / (Math.sqrt(na) * Math.sqrt(nb)); -// return 1 / cosineSimilarity; -// } -// -// @Override -// public String toString() { -// return "inverse cosine similarity distance measure"; -// } -// }); -// measures.add((DistanceMeasure) (a, b) -> { -// double da = FastMath.sqrt(MatrixUtils.createRealVector(a).dotProduct(MatrixUtils.createRealVector(a))); -// double db = FastMath.sqrt(MatrixUtils.createRealVector(b).dotProduct(MatrixUtils.createRealVector(b))); -// return Math.abs(db - da); -// }); + measures.add(new CanberraDistance()); + measures.add(new ChebyshevDistance()); + measures.add(new ManhattanDistance()); + measures.add(new DistanceMeasure() { + @Override + public double compute(double[] a, double[] b) { + double dp = 0.0; + double na = 0.0; + double nb = 0.0; + for (int i = 0; i < a.length; i++) { + dp += a[i] * b[i]; + na += Math.pow(a[i], 2); + nb += Math.pow(b[i], 2); + } + double cosineSimilarity = dp / (Math.sqrt(na) * Math.sqrt(nb)); + return 1 / cosineSimilarity; + } + + @Override + public String toString() { + return "inverse cosine similarity distance measure"; + } + }); + measures.add((DistanceMeasure) (a, b) -> { + double da = FastMath.sqrt(MatrixUtils.createRealVector(a).dotProduct(MatrixUtils.createRealVector(a))); + double db = FastMath.sqrt(MatrixUtils.createRealVector(b).dotProduct(MatrixUtils.createRealVector(b))); + return Math.abs(db - da); + }); for (DistanceMeasure distanceMeasure : measures) { - System.out.println("computing similarity using " + distanceMeasure); + System.out.println("*********************************************"); + System.out.println("*********************************************"); + System.out.println("*** similarity by " + distanceMeasure + "***"); + System.out.println("*********************************************"); + 
System.out.println("*********************************************"); computeSimilarities(vocabulary, wordVectors, distanceMeasure); } @@ -107,7 +126,7 @@ public class SkipGramNetworkTest { private void serialize(List<String> vocabulary, RealMatrix wordVectors) throws IOException { System.out.println("serializing word vectors"); - BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(new File("target/sg-vectors.csv"))); + BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(new File("target/sg-vectors-" + new Date().toString() + ".csv"))); for (int i = 1; i < wordVectors.getColumnDimension(); i++) { double[] a = wordVectors.getColumnVector(i).toArray(); String csq = Arrays.toString(Arrays.copyOfRange(a, 1, a.length)); @@ -163,12 +182,58 @@ public class SkipGramNetworkTest { } if (i > 0 && j0 > 0 && j1 > 0 && j2 > 0) { System.out.println(vocabulary.get(i - 1) + " -> " - + vocabulary.get(j0 - 1) + ", " - + vocabulary.get(j1 - 1) + ", " - + vocabulary.get(j2 - 1)); + + vocabulary.get(j0 - 1) +// + ", " +// + vocabulary.get(j1 - 1) +// + ", " +// + vocabulary.get(j2 - 1) + ); } else { System.err.println("no similarity for '" + vocabulary.get(i) + "' with " + distanceMeasure); } } } + + private void evaluate(SkipGramNetwork network, int window) throws Exception { + double cc = 0; + double wc = 0; + for (Sample sample : network.getSamples()) { + Collection<Integer> exps = new ArrayList<>(window - 1); + Collection<Integer> acts = new ArrayList<>(window - 1); + double[] inputs = sample.getInputs(); + double[] actualOutputs = network.predictOutput(inputs); + double[] expectedOutputs = sample.getOutputs(); + int j = 0; + for (int i = 0; i < window - 1; i++) { + int actualMax = getMaxIndex(actualOutputs, j, j + inputs.length - 1); + int expectedMax = getMaxIndex(expectedOutputs, j, j + inputs.length - 1); + exps.add(expectedMax); + acts.add(actualMax); + j += i + inputs.length - 2; + } + boolean c = true; + for (Integer a : acts) { + c &= exps.contains(a); + } + if (c) { + cc++; + } else { + wc++; + } + } + System.out.println("accuracy: " + (cc / (wc + cc))); + } + + private int getMaxIndex(double[] array, int start, int end) { + double largest = array[start]; + int index = 0; + for (int i = start + 1; i < end; i++) { + if (array[i] >= largest) { + largest = array[i]; + index = i; + } + } + return index; + } + } Modified: labs/yay/trunk/core/src/test/resources/word2vec/test.txt URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/resources/word2vec/test.txt?rev=1732287&r1=1732286&r2=1732287&view=diff ============================================================================== --- labs/yay/trunk/core/src/test/resources/word2vec/test.txt (original) +++ labs/yay/trunk/core/src/test/resources/word2vec/test.txt Thu Feb 25 11:17:56 2016 @@ -1,3 +1,8 @@ the dog saw a cat the dog chased the cat -the cat climbed a tree \ No newline at end of file +the cat climbed a tree +a dog is similar to a cat +dogs eat cats +cats eat rats +rats eat everything +a rat saw something \ No newline at end of file Modified: labs/yay/trunk/pom.xml URL: http://svn.apache.org/viewvc/labs/yay/trunk/pom.xml?rev=1732287&r1=1732286&r2=1732287&view=diff ============================================================================== --- labs/yay/trunk/pom.xml (original) +++ labs/yay/trunk/pom.xml Thu Feb 25 11:17:56 2016 @@ -80,7 +80,7 @@ <dependency> <groupId>org.apache.commons</groupId> <artifactId>commons-math3</artifactId> - <version>3.5</version> + <version>3.6</version> </dependency> 
</dependencies> </dependencyManagement>
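A note on the createTrainingSet change above: each window-sized fragment now contributes exactly one sample, with the middle token (fragment.size() / 2) as the skip-gram input word and the remaining window - 1 tokens as the expected context, replacing the old nested i/j loops. A rough standalone illustration of that pairing, using a hard-coded window-3 fragment from test.txt instead of the builder's tokenizer:

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public class SkipGramPairingSketch {
    public static void main(String[] args) {
        // One window-3 fragment, e.g. from the first sentence of test.txt.
        List<String> fragment = Arrays.asList("the", "dog", "saw");
        int inputIdx = fragment.size() / 2; // middle token is the input word, as in the diff
        String inputWord = fragment.get(inputIdx);
        List<String> outputWords = new ArrayList<String>(fragment.size() - 1);
        for (int k = 0; k < fragment.size(); k++) {
            if (k != inputIdx) {
                outputWords.add(fragment.get(k)); // every other token is an expected context word
            }
        }
        System.out.println(inputWord + " -> " + outputWords); // prints: dog -> [the, saw]
    }
}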