Author: tommaso
Date: Wed Oct 28 14:08:28 2015
New Revision: 1711018

URL: http://svn.apache.org/viewvc?rev=1711018&view=rev
Log:
simplified wordvec test, back to serial matrix update impl
Added: labs/yay/trunk/core/src/test/resources/word2vec/test.txt (with props) Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java labs/yay/trunk/core/src/test/java/org/apache/yay/core/WordVectorsTest.java Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java?rev=1711018&r1=1711017&r2=1711018&view=diff ============================================================================== --- labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java (original) +++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java Wed Oct 28 14:08:28 2015 @@ -18,15 +18,19 @@ */ package org.apache.yay.core; -import org.apache.commons.math3.linear.Array2DRowRealMatrix; -import org.apache.commons.math3.linear.RealMatrix; -import org.apache.yay.*; - import java.util.Arrays; -import java.util.Collection; import java.util.Iterator; -import java.util.LinkedList; -import java.util.concurrent.*; + +import org.apache.commons.math3.linear.Array2DRowRealMatrix; +import org.apache.commons.math3.linear.RealMatrix; +import org.apache.yay.CostFunction; +import org.apache.yay.DerivativeUpdateFunction; +import org.apache.yay.LearningStrategy; +import org.apache.yay.NeuralNetwork; +import org.apache.yay.PredictionStrategy; +import org.apache.yay.TrainingExample; +import org.apache.yay.TrainingSet; +import org.apache.yay.WeightLearningException; /** * Back propagation learning algorithm for neural networks implementation (see @@ -46,8 +50,6 @@ public class BackPropagationLearningStra private final int batch; private final int maxIterations; - private final ExecutorService executorService = Executors.newCachedThreadPool(); - public BackPropagationLearningStrategy(double alpha, double threshold, PredictionStrategy<Double, 
Double> predictionStrategy, CostFunction<RealMatrix, Double, Double> costFunction) { this(alpha, 1, threshold, predictionStrategy, costFunction, MAX_ITERATIONS); @@ -108,7 +110,7 @@ public class BackPropagationLearningStra } else if (iterations > 1 && (cost == newCost || newCost < threshold || iterations > maxIterations)) { System.out.println("successfully converged after " + (iterations - 1) + " iterations (alpha:" + alpha + ",threshold:" + threshold + ") with cost " + newCost + " and parameters " + Arrays.toString(hypothesis.getParameters())); break; - } else if (Double.isNaN(newCost)) { + } else if (Double.isNaN(newCost)){ throw new RuntimeException("failed to converge at iteration " + iterations + " with alpha " + alpha + " : cost calculation underflow"); } @@ -137,7 +139,14 @@ public class BackPropagationLearningStra RealMatrix[] updatedParameters = new RealMatrix[weightsMatrixSet.length]; for (int l = 0; l < weightsMatrixSet.length; l++) { double[][] updatedWeights = weightsMatrixSet[l].getData(); - updateMatrix(derivatives, alpha, l, updatedWeights); + for (int i = 0; i < updatedWeights.length; i++) { + for (int j = 0; j < updatedWeights[i].length; j++) { + double curVal = updatedWeights[i][j]; + if (!(i == 0 && curVal == 0d) && !(j == 0 && curVal == 1d)) { + updatedWeights[i][j] = updatedWeights[i][j] - alpha * derivatives[l].getData()[i][j]; + } + } + } if (updatedParameters[l] != null) { updatedParameters[l].setSubMatrix(updatedWeights, 0, 0); } else { @@ -147,36 +156,4 @@ public class BackPropagationLearningStra return updatedParameters; } - private void updateMatrix(final RealMatrix[] derivatives, final double alpha, final int l, final double[][] updatedWeights) { - Collection<Future<Double>> futures = new LinkedList<Future<Double>>(); - for (int i = 0; i < updatedWeights.length; i++) { - for (int j = 0; j < updatedWeights[i].length; j++) { - final int finalI = i; - final int finalJ = j; - Callable<Double> callable = new Callable<Double>() { - @Override 
- public Double call() throws Exception { - double curVal = updatedWeights[finalI][finalJ]; - double val; - if (!(finalI == 0 && curVal == 0d) && !(finalJ == 0 && curVal == 1d)) { - val = -alpha * derivatives[l].getData()[finalI][finalJ]; - updatedWeights[finalI][finalJ] = val; - } else { - val = curVal; - } - return val; - } - }; - futures.add(executorService.submit(callable)); - } - } - for (Future<Double> f : futures) { - try { - f.get(3, TimeUnit.SECONDS); - } catch (Exception e) { - throw new RuntimeException(e); - } - } - } - -} +} \ No newline at end of file Modified: labs/yay/trunk/core/src/test/java/org/apache/yay/core/WordVectorsTest.java URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/core/WordVectorsTest.java?rev=1711018&r1=1711017&r2=1711018&view=diff ============================================================================== --- labs/yay/trunk/core/src/test/java/org/apache/yay/core/WordVectorsTest.java (original) +++ labs/yay/trunk/core/src/test/java/org/apache/yay/core/WordVectorsTest.java Wed Oct 28 14:08:28 2015 @@ -18,37 +18,21 @@ */ package org.apache.yay.core; +import org.apache.commons.math3.linear.Array2DRowRealMatrix; +import org.apache.commons.math3.linear.RealMatrix; +import org.apache.commons.math3.ml.distance.*; +import org.apache.commons.math3.stat.correlation.PearsonsCorrelation; +import org.apache.yay.*; +import org.junit.Test; + import java.io.BufferedReader; -import java.io.BufferedWriter; -import java.io.File; -import java.io.FileOutputStream; -import java.io.FileWriter; import java.io.IOException; import java.io.InputStream; import java.io.InputStreamReader; -import java.io.ObjectOutputStream; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collection; -import java.util.Collections; -import java.util.HashMap; -import java.util.LinkedList; -import java.util.List; -import java.util.Map; -import java.util.Random; - -import 
org.apache.commons.math3.linear.Array2DRowRealMatrix; -import org.apache.commons.math3.linear.MatrixUtils; -import org.apache.commons.math3.linear.RealMatrix; -import org.apache.yay.ActivationFunction; -import org.apache.yay.Feature; -import org.apache.yay.Input; -import org.apache.yay.NeuralNetwork; -import org.apache.yay.TrainingExample; -import org.apache.yay.TrainingSet; -import org.junit.Test; +import java.util.*; -import static org.junit.Assert.*; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; /** * Integration test for using Yay to implement word vectors algorithms. @@ -68,89 +52,190 @@ public class WordVectorsTest { TrainingSet<Double, Double> trainingSet = createTrainingSet(vocabulary, fragments); TrainingExample<Double, Double> next = trainingSet.iterator().next(); - int inputSize = next.getFeatures().size() ; + int inputSize = next.getFeatures().size(); int outputSize = next.getOutput().length; - int hiddenSize = new Random().nextInt(50) + 15; - System.err.println("i:"+inputSize+",h:"+hiddenSize+",o:"+outputSize); + int hiddenSize = 50; RealMatrix[] randomWeights = createRandomWeights(inputSize, hiddenSize, outputSize); Map<Integer, ActivationFunction<Double>> activationFunctions = new HashMap<Integer, ActivationFunction<Double>>(); activationFunctions.put(0, new IdentityActivationFunction<Double>()); activationFunctions.put(1, new SoftmaxActivationFunction()); FeedForwardStrategy predictionStrategy = new FeedForwardStrategy(activationFunctions); - BackPropagationLearningStrategy learningStrategy = new BackPropagationLearningStrategy(0.03d, 1, - BackPropagationLearningStrategy.DEFAULT_THRESHOLD, predictionStrategy, new LogisticRegressionCostFunction(), 10); + BackPropagationLearningStrategy learningStrategy = new BackPropagationLearningStrategy(0.003d, 1, + BackPropagationLearningStrategy.DEFAULT_THRESHOLD, predictionStrategy, new LogisticRegressionCostFunction(), + 1000); NeuralNetwork neuralNetwork = 
NeuralNetworkFactory.create(randomWeights, learningStrategy, predictionStrategy); RealMatrix[] learnedWeights = neuralNetwork.learn(trainingSet); - RealMatrix wordVectors = learnedWeights[learnedWeights.length - 1]; + RealMatrix wordVectors = learnedWeights[0]; - assertNotNull(wordVectors); + Collection<DistanceMeasure> measures = new LinkedList<DistanceMeasure>(); + measures.add(new EuclideanDistance()); + measures.add(new CanberraDistance()); + measures.add(new ChebyshevDistance()); + measures.add(new ManhattanDistance()); + measures.add(new EarthMoversDistance()); + measures.add(new DistanceMeasure() { + private final PearsonsCorrelation pearsonsCorrelation = new PearsonsCorrelation(); + + @Override + public double compute(double[] a, double[] b) { + return 1 / pearsonsCorrelation.correlation(a, b); + } - RealMatrix mappingsMatrix = MatrixUtils.createRealMatrix(next.getFeatures().size(), next.getOutput().length); + @Override + public String toString() { + return "inverse pearson correlation distance measure"; + } + }); + measures.add(new DistanceMeasure() { + @Override + public double compute(double[] a, double[] b) { + double dp = 0.0; + double na = 0.0; + double nb = 0.0; + for (int i = 0; i < a.length; i++) { + dp += a[i] * b[i]; + na += Math.pow(a[i], 2); + nb += Math.pow(b[i], 2); + } + double cosineSimilarity = dp / (Math.sqrt(na) * Math.sqrt(nb)); + return 1 / cosineSimilarity; + } - BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(new File("target/sg-vectors.txt"))); - int m = 0; - for (String word : vocabulary) { - final Double[] doubles = hotEncode(word, vocabulary); - Input<Double> input = new TrainingExample<Double, Double>() { - @Override - public ArrayList<Feature<Double>> getFeatures() { - ArrayList<Feature<Double>> features = new ArrayList<Feature<Double>>(); - Feature<Double> byasFeature = new Feature<Double>(); - byasFeature.setValue(1d); - features.add(byasFeature); - for (Double d : doubles) { - Feature<Double> f = new 
Feature<Double>(); - f.setValue(d); - features.add(f); - } - return features; - } - - @Override - public Double[] getOutput() { - return new Double[0]; - } - }; - Double[] predict = neuralNetwork.predict(input); - assertNotNull(predict); - double[] row = new double[predict.length]; - for (int x = 0; x < row.length; x++) { - row[x] = predict[x]; - } - mappingsMatrix.setRow(m, row); - m++; - - String vectorString = Arrays.toString(predict); - bufferedWriter.append(vectorString); - bufferedWriter.newLine(); - - Double[] wordVec1 = Arrays.copyOfRange(predict, 0, vocabulary.size()); - assertNotNull(wordVec1); - Double[] wordVec2 = Arrays.copyOfRange(predict, vocabulary.size(), 2 * vocabulary.size()); - assertNotNull(wordVec2); - Double[] wordVec3 = Arrays.copyOfRange(predict, 2 * vocabulary.size(), 3 * vocabulary.size()); - assertNotNull(wordVec3); - - String word1 = hotDecode(wordVec1, vocabulary); - assertNotNull(word1); - assertTrue(vocabulary.contains(word1)); - String word2 = hotDecode(wordVec2, vocabulary); - assertNotNull(word2); - assertTrue(vocabulary.contains(word2)); - String word3 = hotDecode(wordVec3, vocabulary); - assertNotNull(word3); - assertTrue(vocabulary.contains(word3)); + @Override + public String toString() { + return "inverse cosine similarity distance measure"; + } + }); - System.out.println(word + " -> " + word1 + " " + word2 + " " + word3); + for (DistanceMeasure distanceMeasure : measures) { + System.out.println("computing similarity using " + distanceMeasure); + computeSimilarities(vocabulary, wordVectors, distanceMeasure); } - bufferedWriter.flush(); - bufferedWriter.close(); - ObjectOutputStream os = new ObjectOutputStream(new FileOutputStream(new File("target/sg-vectors.bin"))); - MatrixUtils.serializeRealMatrix(mappingsMatrix, os); + assertNotNull(wordVectors); + +// RealMatrix mappingsMatrix = MatrixUtils.createRealMatrix(next.getFeatures().size(), next.getOutput().length); +// +// BufferedWriter bufferedWriter = new BufferedWriter(new 
FileWriter(new File("target/sg-vectors.txt"))); +// int m = 0; +// for (String word : vocabulary) { +// final Double[] doubles = hotEncode(word, vocabulary); +// Input<Double> input = new TrainingExample<Double, Double>() { +// @Override +// public ArrayList<Feature<Double>> getFeatures() { +// ArrayList<Feature<Double>> features = new ArrayList<Feature<Double>>(); +// Feature<Double> byasFeature = new Feature<Double>(); +// byasFeature.setValue(1d); +// features.add(byasFeature); +// for (Double d : doubles) { +// Feature<Double> f = new Feature<Double>(); +// f.setValue(d); +// features.add(f); +// } +// return features; +// } +// +// @Override +// public Double[] getOutput() { +// return new Double[0]; +// } +// }; +// Double[] predict = neuralNetwork.predict(input); +// assertNotNull(predict); +// double[] row = new double[predict.length]; +// for (int x = 0; x < row.length; x++) { +// row[x] = predict[x]; +// } +// mappingsMatrix.setRow(m, row); +// m++; +// +// String vectorString = Arrays.toString(predict); +// bufferedWriter.append(vectorString); +// bufferedWriter.newLine(); +// +// Double[] wordVec1 = Arrays.copyOfRange(predict, 0, vocabulary.size()); +// assertNotNull(wordVec1); +// Double[] wordVec2 = Arrays.copyOfRange(predict, vocabulary.size(), 2 * vocabulary.size()); +// assertNotNull(wordVec2); +// Double[] wordVec3 = Arrays.copyOfRange(predict, 2 * vocabulary.size(), 3 * vocabulary.size()); +// assertNotNull(wordVec3); +// +// String word1 = hotDecode(wordVec1, vocabulary); +// assertNotNull(word1); +// assertTrue(vocabulary.contains(word1)); +// String word2 = hotDecode(wordVec2, vocabulary); +// assertNotNull(word2); +// assertTrue(vocabulary.contains(word2)); +// String word3 = hotDecode(wordVec3, vocabulary); +// assertNotNull(word3); +// assertTrue(vocabulary.contains(word3)); +// +// System.out.println(word + " generates " + word1 + " " + word2 + " " + word3); +// } +// bufferedWriter.flush(); +// bufferedWriter.close(); +// +// 
ObjectOutputStream os = new ObjectOutputStream(new FileOutputStream(new File("target/sg-vectors.bin"))); +// MatrixUtils.serializeRealMatrix(mappingsMatrix, os); + } + + private void computeSimilarities(List<String> vocabulary, RealMatrix wordVectors, DistanceMeasure distanceMeasure) { + for (int i = 1; i < wordVectors.getColumnDimension(); i++) { + double[] subjectVector = wordVectors.getColumn(i); + subjectVector = Arrays.copyOfRange(subjectVector, 1, subjectVector.length); + double maxSimilarity = -Double.MAX_VALUE; + double maxSimilarity1 = -Double.MAX_VALUE; + double maxSimilarity2 = -Double.MAX_VALUE; + double[] bestVector = null; + double[] bestVector1 = null; + double[] bestVector2 = null; + int j0 = -1; + int j1 = -1; + int j2 = -1; + for (int j = 1; j < wordVectors.getColumnDimension(); j++) { + if (i != j) { + double[] vector = wordVectors.getColumn(j); + vector = Arrays.copyOfRange(vector, 1, vector.length); + double similarity = 1 / distanceMeasure.compute(subjectVector, vector); + if (similarity > maxSimilarity) { + maxSimilarity2 = maxSimilarity1; + bestVector2 = bestVector1; + j2 = j1; + + maxSimilarity1 = maxSimilarity; + bestVector1 = bestVector; + j1 = j0; + + maxSimilarity = similarity; + bestVector = vector; + j0 = j; + } else if (similarity > maxSimilarity1) { + maxSimilarity2 = maxSimilarity1; + bestVector2 = bestVector1; + j2 = j1; + + maxSimilarity1 = similarity; + bestVector1 = vector; + j1 = j; + } else if (similarity > maxSimilarity2) { + maxSimilarity2 = similarity; + bestVector2 = vector; + j2 = j; + } + } + } + if (bestVector != null && i > 0 && j0 > 0 && j1 > 0 && j2 > 0) { + System.out.println(vocabulary.get(i - 1) + " is similar to " + + vocabulary.get(j0 - 1) + ", " + + vocabulary.get(j1 - 1) + ", " + + vocabulary.get(j2 - 1)); + } else { + throw new RuntimeException(); + } + } } private String hotDecode(Double[] doubles, List<String> vocabulary) { @@ -232,6 +317,7 @@ public class WordVectorsTest { } } } + 
Collections.sort(vocabulary); return vocabulary; } @@ -261,7 +347,7 @@ public class WordVectorsTest { } private Collection<String> getSentences() throws IOException { - InputStream resourceAsStream = getClass().getResourceAsStream("/word2vec/sentences.txt"); + InputStream resourceAsStream = getClass().getResourceAsStream("/word2vec/test.txt"); BufferedReader bufferedReader = new BufferedReader(new InputStreamReader(resourceAsStream)); Collection<String> sentences = new LinkedList<String>(); String line; Added: labs/yay/trunk/core/src/test/resources/word2vec/test.txt URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/resources/word2vec/test.txt?rev=1711018&view=auto ============================================================================== --- labs/yay/trunk/core/src/test/resources/word2vec/test.txt (added) +++ labs/yay/trunk/core/src/test/resources/word2vec/test.txt Wed Oct 28 14:08:28 2015 @@ -0,0 +1,3 @@ +the dog saw a cat +the dog chased the cat +the cat climbed a tree \ No newline at end of file Propchange: labs/yay/trunk/core/src/test/resources/word2vec/test.txt ------------------------------------------------------------------------------ svn:eol-style = native --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@labs.apache.org For additional commands, e-mail: commits-h...@labs.apache.org