Author: tommaso
Date: Thu Feb 18 10:13:42 2016
New Revision: 1731036

URL: http://svn.apache.org/viewvc?rev=1731036&view=rev
Log:
refactored SFFNN to MLN, added ReLU function and (compact) skip-gram
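For orientation, the renamed MultiLayerNetwork keeps the old ShallowFeedForwardNeuralNetwork API: configure layer sizes and activation functions, learn weights from Sample instances, then predict. A minimal usage sketch along the lines of the updated MultiLayerNetworkTest below (the demo class, its placement in the org.apache.yay package for package-private Configuration access, and the sample values are illustrative, not part of this commit):

    package org.apache.yay;

    /**
     * Illustrative sketch of the MultiLayerNetwork API exercised by MultiLayerNetworkTest.
     */
    public class MultiLayerNetworkExample {

      public static void main(String[] args) throws Exception {
        // a 3-4-1 sigmoid network, mirroring testLearnAndPredict
        MultiLayerNetwork.Configuration configuration = new MultiLayerNetwork.Configuration();
        configuration.alpha = 0.00001d;                // learning rate
        configuration.layers = new int[]{3, 4, 1};     // units per layer: input, hidden, output
        configuration.maxIterations = 10000;
        configuration.threshold = 0.00000004d;         // stop once the cost drops below this
        configuration.activationFunctions = new ActivationFunction[]{new SigmoidFunction()};

        MultiLayerNetwork network = new MultiLayerNetwork(configuration);

        // each Sample pairs an input vector with its expected output vector (made-up data)
        Sample[] samples = new Sample[]{
            new Sample(new double[]{0.2d, 0.7d, 0.1d}, new double[]{1d}),
            new Sample(new double[]{0.9d, 0.1d, 0.8d}, new double[]{0d}),
            new Sample(new double[]{0.3d, 0.6d, 0.2d}, new double[]{1d})
        };
        double cost = network.learnWeights(samples);   // backpropagation via SGD, returns final cost

        double[] prediction = network.predictOutput(new double[]{0.2d, 0.7d, 0.1d});
        System.out.println("cost: " + cost + ", prediction: " + prediction[0]);
      }
    }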
Added:
    labs/yay/trunk/core/src/main/java/org/apache/yay/MultiLayerNetwork.java
      - copied, changed from r1724846, labs/yay/trunk/core/src/main/java/org/apache/yay/ShallowFeedForwardNeuralNetwork.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/RectifierFunction.java   (with props)
    labs/yay/trunk/core/src/main/java/org/apache/yay/SkipGramNetwork.java   (with props)
    labs/yay/trunk/core/src/test/java/org/apache/yay/MultiLayerNetworkTest.java
      - copied, changed from r1724846, labs/yay/trunk/core/src/test/java/org/apache/yay/ShallowFeedForwardNeuralNetworkTest.java
    labs/yay/trunk/core/src/test/java/org/apache/yay/SkipGramNetworkTest.java   (with props)
Removed:
    labs/yay/trunk/core/src/main/java/org/apache/yay/ShallowFeedForwardNeuralNetwork.java
    labs/yay/trunk/core/src/test/java/org/apache/yay/ShallowFeedForwardNeuralNetworkTest.java
    labs/yay/trunk/core/src/test/java/org/apache/yay/WordVectorsTest.java
Modified:
    labs/yay/trunk/core/pom.xml

Modified: labs/yay/trunk/core/pom.xml
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/pom.xml?rev=1731036&r1=1731035&r2=1731036&view=diff
==============================================================================
--- labs/yay/trunk/core/pom.xml (original)
+++ labs/yay/trunk/core/pom.xml Thu Feb 18 10:13:42 2016
@@ -51,7 +51,6 @@
       <groupId>com.google.guava</groupId>
       <artifactId>guava</artifactId>
       <version>18.0</version>
-      <scope>test</scope>
     </dependency>
   </dependencies>
   <build>

Copied: labs/yay/trunk/core/src/main/java/org/apache/yay/MultiLayerNetwork.java (from r1724846, labs/yay/trunk/core/src/main/java/org/apache/yay/ShallowFeedForwardNeuralNetwork.java)
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/MultiLayerNetwork.java?p2=labs/yay/trunk/core/src/main/java/org/apache/yay/MultiLayerNetwork.java&p1=labs/yay/trunk/core/src/main/java/org/apache/yay/ShallowFeedForwardNeuralNetwork.java&r1=1724846&r2=1731036&rev=1731036&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/ShallowFeedForwardNeuralNetwork.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/MultiLayerNetwork.java Thu Feb 18 10:13:42 2016
@@ -28,13 +28,11 @@ import java.util.Arrays;
 import java.util.Random;

 /**
- * A shallow feed forward neural network.
+ * A multi layer feed forward neural network.
  * It learns its weights through backpropagation algorithm via stochastic gradient descent applied to a collection of
  * training samples.
- * Each example is a real vectors whose first N elements (identified by the no. of network input units) are the actual
- * outputs and the remaining elements are the input features.
*/ -public class ShallowFeedForwardNeuralNetwork { +public class MultiLayerNetwork { private final Configuration configuration; @@ -51,12 +49,12 @@ public class ShallowFeedForwardNeuralNet */ private RealMatrix[] weights; - public ShallowFeedForwardNeuralNetwork(Configuration configuration) { + public MultiLayerNetwork(Configuration configuration) { this.configuration = configuration; this.weights = createRandomWeights(); } - public ShallowFeedForwardNeuralNetwork(Configuration configuration, RealMatrix[] weights) { + public MultiLayerNetwork(Configuration configuration, RealMatrix[] weights) { this.configuration = configuration; this.weights = weights; } @@ -143,7 +141,7 @@ public class ShallowFeedForwardNeuralNet if (Double.POSITIVE_INFINITY == newCost) { throw new Exception("failed to converge at iteration " + iterations + " with alpha " + configuration.alpha + " : cost going from " + cost + " to " + newCost); - } else if (iterations > 1 && (cost == newCost || newCost < configuration.threshold || iterations > configuration.maxIterations)) { + } else if (iterations > 1 && (newCost < configuration.threshold || iterations > configuration.maxIterations)) { System.out.println("successfully converged after " + (iterations - 1) + " iterations (alpha:" + configuration.alpha + ",threshold:" + configuration.threshold + ") with cost " + newCost); break; } else if (Double.isNaN(newCost)) { @@ -204,13 +202,13 @@ public class ShallowFeedForwardNeuralNet } // compute derivatives - for (int i = 0; i < deltas.length; i++) { - ds[i] = deltas[i].scalarMultiply(1d / size); - } +// for (int i = 0; i < deltas.length; i++) { +// ds[i] = deltas[i].scalarMultiply(1d / size); +// } // regularization int l = 0; - for (RealMatrix d : ds) { + for (RealMatrix d : deltas) { final int finalL = l; d.walkInOptimizedOrder(new RealMatrixChangingVisitor() { @Override @@ -221,7 +219,7 @@ public class ShallowFeedForwardNeuralNet @Override public double visit(int row, int column, double value) { if (column != 0) { - return value + configuration.alpha * weights[finalL].getEntry(row, column); + return value + configuration.alpha * weights[finalL].getEntry(row, column); // assuming regularization factor == learning rate } else { return value; } @@ -235,7 +233,7 @@ public class ShallowFeedForwardNeuralNet l++; } - return ds; + return deltas; } private RealVector calculateDeltaVector(RealMatrix weight, RealVector activationsVector, RealVector nextLayerDelta) { @@ -294,7 +292,17 @@ public class ShallowFeedForwardNeuralNet res += yo * Math.log(ho) + (1d - yo) * Math.log(1d - ho); } + return (-1d / size) * res; + +// Double res = 0d; +// +// for (int i = 0; i < predictedOutput.length; i++) { +// Double so = expectedOutput[i]; +// Double po = predictedOutput[i]; +// res -= so * Math.log(po); +// } +// return res; } // --- feed forward --- Added: labs/yay/trunk/core/src/main/java/org/apache/yay/RectifierFunction.java URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/RectifierFunction.java?rev=1731036&view=auto ============================================================================== --- labs/yay/trunk/core/src/main/java/org/apache/yay/RectifierFunction.java (added) +++ labs/yay/trunk/core/src/main/java/org/apache/yay/RectifierFunction.java Thu Feb 18 10:13:42 2016 @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.yay; + +import org.apache.commons.math3.linear.RealMatrix; +import org.apache.commons.math3.linear.RealMatrixChangingVisitor; + +/** + * Rectifier (aka ReLU) activation function + */ +public class RectifierFunction implements ActivationFunction { + @Override + public RealMatrix applyMatrix(RealMatrix weights) { + RealMatrix matrix = weights.copy(); + matrix.walkInOptimizedOrder(new RealMatrixChangingVisitor() { + @Override + public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) { + + } + + @Override + public double visit(int row, int column, double value) { + return Math.max(0, value); + } + + @Override + public double end() { + return 0; + } + }); + return matrix; + } +} Propchange: labs/yay/trunk/core/src/main/java/org/apache/yay/RectifierFunction.java ------------------------------------------------------------------------------ svn:eol-style = native Added: labs/yay/trunk/core/src/main/java/org/apache/yay/SkipGramNetwork.java URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/SkipGramNetwork.java?rev=1731036&view=auto ============================================================================== --- labs/yay/trunk/core/src/main/java/org/apache/yay/SkipGramNetwork.java (added) +++ labs/yay/trunk/core/src/main/java/org/apache/yay/SkipGramNetwork.java Thu Feb 18 10:13:42 2016 @@ -0,0 +1,549 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */
+package org.apache.yay;
+
+import com.google.common.base.Splitter;
+import org.apache.commons.lang3.ArrayUtils;
+import org.apache.commons.math3.linear.MatrixUtils;
+import org.apache.commons.math3.linear.RealMatrix;
+import org.apache.commons.math3.linear.RealMatrixChangingVisitor;
+import org.apache.commons.math3.linear.RealMatrixPreservingVisitor;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.nio.CharBuffer;
+import java.nio.channels.SeekableByteChannel;
+import java.nio.charset.Charset;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Queue;
+import java.util.Random;
+import java.util.Set;
+import java.util.concurrent.ConcurrentLinkedDeque;
+import java.util.regex.Pattern;
+
+/**
+ * A skip-gram neural network.
+ * It learns its weights through the backpropagation algorithm via batch gradient descent applied to a collection of
+ * hot encoded training samples.
+ */
+public class SkipGramNetwork {
+
+  private final Configuration configuration;
+  private final RectifierFunction rectifierFunction = new RectifierFunction();
+  private final SoftmaxActivationFunction softmaxActivationFunction = new SoftmaxActivationFunction();
+
+  /**
+   * Each RealMatrix maps weights between two layers.
+   * E.g.: weights[0] controls function mapping from layer 0 to layer 1.
+   * If network has 4 units in layer 1 and 5 units in layer 2, then weights[0] will be of dimension 5x4, plus bias terms.
+   * A network having layers with 3, 4 and 2 units each will have the following weights matrix dimensions:
+   * - weights[0] : 4x3
+   * - weights[1] : 2x4
+   * <p>
+   * the first row of the weights[0] matrix holds the weights entering the first neuron of the second layer,
+   * the second row of weights[0] holds the weights entering the second neuron of the second layer, etc.
+   */
+  private RealMatrix[] weights;
+
+
+  private SkipGramNetwork(Configuration configuration) {
+    this.configuration = configuration;
+    this.weights = createRandomWeights();
+  }
+
+  public RealMatrix[] getWeights() {
+    return weights;
+  }
+
+  public List<String> getVocabulary() {
+    return configuration.vocabulary;
+  }
+
+  private RealMatrix[] createRandomWeights() {
+    Random r = new Random();
+    int[] conf = new int[]{configuration.inputs, configuration.vectorSize, configuration.outputs};
+    int[] layers = new int[conf.length];
+    for (int i = 0; i < layers.length; i++) {
+      layers[i] = conf[i] + (i < layers.length - 1 ?
1 : 0); + } + int weightsCount = layers.length - 1; + + RealMatrix[] initialWeights = new RealMatrix[weightsCount]; + + for (int i = 0; i < weightsCount; i++) { + + RealMatrix matrix = MatrixUtils.createRealMatrix(layers[i + 1], layers[i]); + final int finalI = i; + matrix.walkInOptimizedOrder(new RealMatrixChangingVisitor() { + @Override + public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) { + + } + + @Override + public double visit(int row, int column, double value) { + if (finalI != weightsCount - 1 && row == 0) { + return 0d; + } else if (column == 0) { + return 1d; + } + return r.nextInt(100) / 101d; + } + + @Override + public double end() { + return 0; + } + }); + + initialWeights[i] = matrix; + } + return initialWeights; + + } + + // --- batch gradient descent --- + + /** + * perform weights learning from the training examples using batch gradient descent algorithm + * + * @param samples the training examples + * @return the final cost with the updated weights + * @throws Exception if BGD fails to converge or any numerical error happens + */ + public double learnWeights(Sample... samples) throws Exception { + + int iterations = 0; + + double cost = Double.MAX_VALUE; + long start = System.currentTimeMillis(); + while (true) { + if (iterations % (1 + (configuration.maxIterations / 100)) == 0) { + long time = (System.currentTimeMillis() - start) / 1000; + if (time > 60) { + System.out.println("cost is " + cost + " after " + iterations + " iterations in " + (time / 60) + " minutes (" + ((double) iterations / time) + " ips)"); + } + } + double newCost = 0; + RealMatrix x = MatrixUtils.createRealMatrix(samples.length, samples[0].getInputs().length); + RealMatrix y = MatrixUtils.createRealMatrix(samples.length, samples[0].getOutputs().length); + int i = 0; + for (Sample sample : samples) { + x.setRow(i, ArrayUtils.addAll(sample.getInputs())); + y.setRow(i, ArrayUtils.addAll(sample.getOutputs())); + i++; + } + + RealMatrix hidden = rectifierFunction.applyMatrix(x.multiply(weights[0].transpose())); + RealMatrix scores = hidden.multiply(weights[1].transpose()); + + RealMatrix probs = softmaxActivationFunction.applyMatrix(scores); + + RealMatrix correctLogProbs = MatrixUtils.createRealMatrix(x.getRowDimension(), 1); + correctLogProbs.walkInOptimizedOrder(new RealMatrixChangingVisitor() { + @Override + public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) { + + } + + @Override + public double visit(int row, int column, double value) { + return -Math.log(probs.getEntry(row, getMaxIndex(y.getRow(row)))); + } + + @Override + public double end() { + return 0; + } + }); + double dataLoss = correctLogProbs.walkInOptimizedOrder(new RealMatrixPreservingVisitor() { + private double d = 0; + + @Override + public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) { + + } + + @Override + public void visit(int row, int column, double value) { + d += value; + } + + @Override + public double end() { + return d; + } + }) / samples.length; + + double reg = 0d; + reg += weights[0].walkInOptimizedOrder(new RealMatrixPreservingVisitor() { + private double d = 0d; + + @Override + public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) { + + } + + @Override + public void visit(int row, int column, double value) { + d += Math.pow(value, 2); + } + + @Override + public double end() { + return d; + } + }); + newCost = dataLoss + 0.5 * 0.03 * reg; 
+ + if (Double.POSITIVE_INFINITY == newCost || newCost > cost) { + throw new Exception("failed to converge at iteration " + iterations + " with alpha " + configuration.alpha + " : cost going from " + cost + " to " + newCost); + } else if (iterations > 1 && (newCost < configuration.threshold || iterations > configuration.maxIterations)) { + cost = newCost; + System.out.println("successfully converged after " + (iterations - 1) + " iterations (alpha:" + configuration.alpha + ",threshold:" + configuration.threshold + ") with cost " + newCost); + break; + } else if (Double.isNaN(newCost)) { + throw new Exception("failed to converge at iteration " + iterations + " with alpha " + configuration.alpha + " : cost calculation underflow"); + } + + // update registered cost + cost = newCost; + + // calculate the derivatives to update the parameters + + RealMatrix dscores = probs; + dscores.walkInOptimizedOrder(new RealMatrixChangingVisitor() { + @Override + public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) { + + } + + @Override + public double visit(int row, int column, double value) { + return y.getEntry(row, column) == 1 ? (value - 1) / samples.length : value / samples.length; + } + + @Override + public double end() { + return 0; + } + }); + + + RealMatrix dW2 = hidden.transpose().multiply(dscores); + + dW2.walkInOptimizedOrder(new RealMatrixChangingVisitor() { + @Override + public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) { + + } + + @Override + public double visit(int row, int column, double value) { + if (column != 0) { + return value + 0.03 * weights[1].transpose().getEntry(row, column); + } else { + return value; + } + } + + @Override + public double end() { + return 0; + } + }); + + RealMatrix dhidden = dscores.multiply(weights[1]); + + RealMatrix dW = x.transpose().multiply(dhidden); + dW.walkInOptimizedOrder(new RealMatrixChangingVisitor() { + @Override + public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) { + + } + + @Override + public double visit(int row, int column, double value) { + if (column != 0) { + return value + 0.03 * weights[0].transpose().getEntry(row, column); + } else { + return value; + } + } + + @Override + public double end() { + return 0; + } + }); + + RealMatrix[] derivatives = new RealMatrix[]{dW.transpose(), dW2.transpose()}; + + // update the weights + RealMatrix[] updatedParameters = new RealMatrix[weights.length]; + + for (int l = 0; l < weights.length; l++) { + RealMatrix realMatrix = weights[l].copy(); + final int finalL = l; + RealMatrixChangingVisitor visitor = new RealMatrixChangingVisitor() { + + @Override + public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) { + + } + + @Override + public double visit(int row, int column, double value) { + if (!(row == 0 && value == 0d) && !(column == 0 && value == 1d)) { + return value - configuration.alpha * derivatives[finalL].getEntry(row, column); + } else { + return value; + } + } + + @Override + public double end() { + return 0; + } + }; + realMatrix.walkInOptimizedOrder(visitor); + updatedParameters[l] = realMatrix; + } + weights = updatedParameters; + + iterations++; + } + return cost; + } + + private int getMaxIndex(double[] array) { + double largest = array[0]; + int index = 0; + for (int i = 1; i < array.length; i++) { + if (array[i] >= largest) { + largest = array[i]; + index = i; + } + } + return index; + } + + public 
static SkipGramNetwork.Builder newModel() { + return new Builder(); + } + + // --- skip gram neural network configuration --- + + private static class Configuration { + // internal parameters + protected int outputs; + protected int inputs; + + protected List<String> vocabulary; + + // user controlled parameters + protected Path path; + protected int maxIterations; + protected double alpha = 0.001d; + protected double threshold = 0.004d; + protected int vectorSize; + protected int window; + } + + public static class Builder { + private final Configuration configuration; + + public Builder() { + this.configuration = new Configuration(); + } + + + public Builder withWindow(int w) { + this.configuration.window = w; + return this; + } + + public Builder fromTextAt(Path path) { + this.configuration.path = path; + return this; + } + + public Builder withDimension(int d) { + this.configuration.vectorSize = d; + return this; + } + + public SkipGramNetwork build() throws Exception { + System.out.println("reading fragments"); + Queue<List<byte[]>> fragments = getFragments(this.configuration.path, this.configuration.window); + assert !fragments.isEmpty() : "could not read fragments"; + System.out.println("generating vocabulary"); + List<String> vocabulary = getVocabulary(this.configuration.path); + assert !vocabulary.isEmpty() : "could not read vocabulary"; + this.configuration.vocabulary = vocabulary; + + System.out.println("creating training set"); + Collection<HotEncodedSample> trainingSet = createTrainingSet(vocabulary, fragments, this.configuration.window); + fragments.clear(); + this.configuration.maxIterations = trainingSet.size() * 10; + + HotEncodedSample next = trainingSet.iterator().next(); + + this.configuration.inputs = next.getInputs().length - 1; + this.configuration.outputs = next.getOutputs().length; + + SkipGramNetwork network = new SkipGramNetwork(configuration); + network.learnWeights(trainingSet.toArray(new Sample[trainingSet.size()])); + return network; + } + + private Collection<HotEncodedSample> createTrainingSet(final List<String> vocabulary, Queue<List<byte[]>> fragments, int window) throws IOException { + long start = System.currentTimeMillis(); + Collection<HotEncodedSample> samples = new LinkedList<>(); + List<byte[]> fragment; + while ((fragment = fragments.poll()) != null) { + byte[] inputWord = null; + List<byte[]> outputWords = new ArrayList<>(fragment.size() - 1); + for (int i = 0; i < fragment.size(); i++) { + for (int j = 0; j < fragment.size(); j++) { + byte[] token = fragment.get(i); + if (i == j) { + inputWord = token; + } else { + outputWords.add(token); + } + } + } + final byte[] finalInputWord = inputWord; + + double[] doubles = new double[window - 1]; + for (int i = 0; i < doubles.length; i++) { + doubles[i] = (double) vocabulary.indexOf(new String(outputWords.get(i))); + } + + double[] inputs = new double[1]; + inputs[0] = (double) vocabulary.indexOf(new String(finalInputWord)); + + samples.add(new HotEncodedSample(inputs, doubles, vocabulary.size())); + + } + + long end = System.currentTimeMillis(); + System.out.println("training set created in " + (end - start) / 60000 + " minutes"); + + return samples; + } + + private Queue<List<byte[]>> getFragments(Path path, int w) throws IOException { + long start = System.currentTimeMillis(); + Queue<List<byte[]>> fragments = new ConcurrentLinkedDeque<>(); + + ByteBuffer buf = ByteBuffer.allocate(100); + try (SeekableByteChannel sbc = Files.newByteChannel(path)) { + + String encoding = 
System.getProperty("file.encoding"); + StringBuilder previous = new StringBuilder(); + Splitter splitter = Splitter.on(Pattern.compile("[\\n\\s]")).omitEmptyStrings().trimResults(); + while (sbc.read(buf) > 0) { + buf.rewind(); + CharBuffer charBuffer = Charset.forName(encoding).decode(buf); + String string = cleanString(charBuffer); + List<String> split = splitter.splitToList(string); + int splitSize = split.size(); + if (splitSize > w) { + for (int j = 0; j < splitSize - w; j++) { + List<byte[]> fragment = new ArrayList<>(w); + fragment.add(previous.append(split.get(j)).toString().getBytes()); + for (int i = 1; i < w; i++) { + fragment.add(split.get(i + j).getBytes()); + } + // TODO : this has to be used to re-use the tokens that have not been consumed in next iteration + fragments.add(fragment); + previous = new StringBuilder(); + } + previous = new StringBuilder().append(split.get(splitSize - 1)); + } else if (split.size() == w) { + previous.append(string); + } + buf.flip(); + } + } catch (IOException x) { + System.err.println("caught exception: " + x); + } finally { + buf.clear(); + } + long end = System.currentTimeMillis(); + System.out.println("fragments read in " + (end - start) / 60000 + " minutes (" + fragments.size() + ")"); + return fragments; + } + + private List<String> getVocabulary(Path path) throws IOException { + Set<String> vocabulary = new HashSet<>(); + ByteBuffer buf = ByteBuffer.allocate(100); + try (SeekableByteChannel sbc = Files.newByteChannel(path)) { + + String encoding = System.getProperty("file.encoding"); + StringBuilder previous = new StringBuilder(); + Splitter splitter = Splitter.on(Pattern.compile("[\\\n\\s]")).omitEmptyStrings().trimResults(); + while (sbc.read(buf) > 0) { + buf.rewind(); + CharBuffer charBuffer = Charset.forName(encoding).decode(buf); + String string = cleanString(charBuffer); + List<String> split = splitter.splitToList(string); + int splitSize = split.size(); + if (splitSize > 1) { + String term = previous.append(split.get(0)).toString(); + vocabulary.add(term.intern()); + for (int i = 1; i < splitSize - 1; i++) { + String term2 = split.get(i); + vocabulary.add(term2.intern()); + } + previous = new StringBuilder().append(split.get(splitSize - 1)); + } else if (split.size() == 1) { + previous.append(string); + } + buf.flip(); + } + } catch (IOException x) { + System.err.println("caught exception: " + x); + } finally { + buf.clear(); + } + List<String> list = Arrays.asList(vocabulary.toArray(new String[vocabulary.size()])); + Collections.sort(list); +// for (String iw : vocabulary) { +// System.out.println(iw +"->"+Arrays.toString(ConversionUtils.hotEncode(iw.getBytes(), list))); +// } + return list; + } + + private String cleanString(CharBuffer charBuffer) { + String s = charBuffer.toString(); + return s.toLowerCase().replaceAll("\\.", " ").replaceAll("\\;", " ").replaceAll("\\,", " ").replaceAll("\\:", " ").replaceAll("\\-\\s", "").replaceAll("\\\"", ""); + } + } +} \ No newline at end of file Propchange: labs/yay/trunk/core/src/main/java/org/apache/yay/SkipGramNetwork.java ------------------------------------------------------------------------------ svn:eol-style = native Copied: labs/yay/trunk/core/src/test/java/org/apache/yay/MultiLayerNetworkTest.java (from r1724846, labs/yay/trunk/core/src/test/java/org/apache/yay/ShallowFeedForwardNeuralNetworkTest.java) URL: 
http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/MultiLayerNetworkTest.java?p2=labs/yay/trunk/core/src/test/java/org/apache/yay/MultiLayerNetworkTest.java&p1=labs/yay/trunk/core/src/test/java/org/apache/yay/ShallowFeedForwardNeuralNetworkTest.java&r1=1724846&r2=1731036&rev=1731036&view=diff ============================================================================== --- labs/yay/trunk/core/src/test/java/org/apache/yay/ShallowFeedForwardNeuralNetworkTest.java (original) +++ labs/yay/trunk/core/src/test/java/org/apache/yay/MultiLayerNetworkTest.java Thu Feb 18 10:13:42 2016 @@ -22,25 +22,27 @@ import org.apache.commons.math3.linear.A import org.apache.commons.math3.linear.RealMatrix; import org.junit.Test; +import java.util.Random; + import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; /** - * Tests for {@link ShallowFeedForwardNeuralNetwork} + * Tests for {@link MultiLayerNetwork} */ -public class ShallowFeedForwardNeuralNetworkTest { +public class MultiLayerNetworkTest { @Test public void testLearnAndPredict() throws Exception { - ShallowFeedForwardNeuralNetwork.Configuration configuration = new ShallowFeedForwardNeuralNetwork.Configuration(); - configuration.alpha = 0.0001d; + MultiLayerNetwork.Configuration configuration = new MultiLayerNetwork.Configuration(); + configuration.alpha = 0.00001d; configuration.layers = new int[]{3, 4, 1}; configuration.maxIterations = 10000; - configuration.threshold = 0.004d; + configuration.threshold = 0.00000004d; configuration.activationFunctions = new ActivationFunction[]{new SigmoidFunction()}; - ShallowFeedForwardNeuralNetwork neuralNetwork = new ShallowFeedForwardNeuralNetwork(configuration); + MultiLayerNetwork neuralNetwork = new MultiLayerNetwork(configuration); assertNotNull(neuralNetwork); Sample[] samples = new Sample[3]; @@ -55,6 +57,21 @@ public class ShallowFeedForwardNeuralNet assertNotNull(doubles); assertEquals(0.9d, doubles[0], 0.2d); + + samples = createRandomSamples(10000); + cost = neuralNetwork.learnWeights(samples); + assertTrue(cost > 0 && cost < 10); + } + + private Sample[] createRandomSamples(int size) { + Random r = new Random(); + Sample[] samples = new Sample[size]; + for (int i = 0; i < size; i++) { + boolean l = r.nextBoolean(); + samples[i] = new Sample(new double[]{r.nextDouble(), r.nextDouble(), r.nextDouble()}, l ? 
new double[]{1d} : + new double[]{0d}); + } + return samples; } @Test @@ -63,14 +80,14 @@ public class ShallowFeedForwardNeuralNet RealMatrix singleAndLayerWeights = new Array2DRowRealMatrix(weights); RealMatrix[] andRealMatrixSet = new RealMatrix[]{singleAndLayerWeights}; - ShallowFeedForwardNeuralNetwork.Configuration configuration = new ShallowFeedForwardNeuralNetwork.Configuration(); + MultiLayerNetwork.Configuration configuration = new MultiLayerNetwork.Configuration(); configuration.alpha = 0.0001d; configuration.layers = new int[]{2, 1}; configuration.maxIterations = 10000; configuration.threshold = 0.004d; configuration.activationFunctions = new ActivationFunction[]{new SigmoidFunction()}; - ShallowFeedForwardNeuralNetwork and = new ShallowFeedForwardNeuralNetwork(configuration, andRealMatrixSet); + MultiLayerNetwork and = new MultiLayerNetwork(configuration, andRealMatrixSet); assertEquals(0L, Math.round(and.predictOutput(new double[]{1d, 0d})[0])); assertEquals(0L, Math.round(and.predictOutput(new double[]{0d, 1d})[0])); @@ -84,14 +101,14 @@ public class ShallowFeedForwardNeuralNet RealMatrix singleOrLayerWeights = new Array2DRowRealMatrix(weights); RealMatrix[] orRealMatrixSet = new RealMatrix[]{singleOrLayerWeights}; - ShallowFeedForwardNeuralNetwork.Configuration configuration = new ShallowFeedForwardNeuralNetwork.Configuration(); + MultiLayerNetwork.Configuration configuration = new MultiLayerNetwork.Configuration(); configuration.alpha = 0.0001d; configuration.layers = new int[]{2, 1}; configuration.maxIterations = 10000; configuration.threshold = 0.004d; configuration.activationFunctions = new ActivationFunction[]{new SigmoidFunction()}; - ShallowFeedForwardNeuralNetwork or = new ShallowFeedForwardNeuralNetwork(configuration, orRealMatrixSet); + MultiLayerNetwork or = new MultiLayerNetwork(configuration, orRealMatrixSet); assertEquals(1L, Math.round(or.predictOutput(new double[]{1d, 0d})[0])); assertEquals(1L, Math.round(or.predictOutput(new double[]{0d, 1d})[0])); @@ -105,14 +122,14 @@ public class ShallowFeedForwardNeuralNet RealMatrix singleNotLayerWeights = new Array2DRowRealMatrix(weights); RealMatrix[] notRealMatrixSet = new RealMatrix[]{singleNotLayerWeights}; - ShallowFeedForwardNeuralNetwork.Configuration configuration = new ShallowFeedForwardNeuralNetwork.Configuration(); + MultiLayerNetwork.Configuration configuration = new MultiLayerNetwork.Configuration(); configuration.alpha = 0.0001d; configuration.layers = new int[]{1, 1}; configuration.maxIterations = 10000; configuration.threshold = 0.004d; configuration.activationFunctions = new ActivationFunction[]{new SigmoidFunction()}; - ShallowFeedForwardNeuralNetwork not = new ShallowFeedForwardNeuralNetwork(configuration, notRealMatrixSet); + MultiLayerNetwork not = new MultiLayerNetwork(configuration, notRealMatrixSet); assertEquals(1L, Math.round(not.predictOutput(new double[]{0d})[0])); assertEquals(0L, Math.round(not.predictOutput(new double[]{1d})[0])); } @@ -123,14 +140,14 @@ public class ShallowFeedForwardNeuralNet RealMatrix secondNorLayerWeights = new Array2DRowRealMatrix(new double[][]{{-10d, 20d, 20d}}); RealMatrix[] norRealMatrixSet = new RealMatrix[]{firstNorLayerWeights, secondNorLayerWeights}; - ShallowFeedForwardNeuralNetwork.Configuration configuration = new ShallowFeedForwardNeuralNetwork.Configuration(); + MultiLayerNetwork.Configuration configuration = new MultiLayerNetwork.Configuration(); configuration.alpha = 0.0001d; configuration.layers = new int[]{2, 2, 1}; configuration.maxIterations = 
10000; configuration.threshold = 0.004d; configuration.activationFunctions = new ActivationFunction[]{new SigmoidFunction()}; - ShallowFeedForwardNeuralNetwork nor = new ShallowFeedForwardNeuralNetwork(configuration, norRealMatrixSet); + MultiLayerNetwork nor = new MultiLayerNetwork(configuration, norRealMatrixSet); assertEquals(0L, Math.round(nor.predictOutput(new double[]{1d, 0d})[0])); Added: labs/yay/trunk/core/src/test/java/org/apache/yay/SkipGramNetworkTest.java URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/SkipGramNetworkTest.java?rev=1731036&view=auto ============================================================================== --- labs/yay/trunk/core/src/test/java/org/apache/yay/SkipGramNetworkTest.java (added) +++ labs/yay/trunk/core/src/test/java/org/apache/yay/SkipGramNetworkTest.java Thu Feb 18 10:13:42 2016 @@ -0,0 +1,174 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.yay; + +import org.apache.commons.math3.linear.RealMatrix; +import org.apache.commons.math3.ml.distance.DistanceMeasure; +import org.apache.commons.math3.ml.distance.EuclideanDistance; +import org.junit.Test; + +import java.io.BufferedWriter; +import java.io.File; +import java.io.FileWriter; +import java.io.IOException; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.Arrays; +import java.util.Collection; +import java.util.LinkedList; +import java.util.List; + +/** + * Tests for skip gram network + */ +public class SkipGramNetworkTest { + + @Test + public void testWordVectorsLearningOnAbstracts() throws Exception { + Path path = Paths.get(getClass().getResource("/word2vec/abstracts.txt").getFile()); + SkipGramNetwork network = SkipGramNetwork.newModel().withWindow(4).fromTextAt(path).withDimension(10).build(); + RealMatrix wv = network.getWeights()[0]; + List<String> vocabulary = network.getVocabulary(); + serialize(vocabulary, wv); + measure(vocabulary, wv); + } + + @Test + public void testWordVectorsLearningOnSentences() throws Exception { + Path path = Paths.get(getClass().getResource("/word2vec/sentences.txt").getFile()); + SkipGramNetwork network = SkipGramNetwork.newModel().withWindow(4).fromTextAt(path).withDimension(10).build(); + RealMatrix wv = network.getWeights()[0]; + List<String> vocabulary = network.getVocabulary(); + serialize(vocabulary, wv); + measure(vocabulary, wv); + } + + @Test + public void testWordVectorsLearningOnTestData() throws Exception { + Path path = Paths.get(getClass().getResource("/word2vec/test.txt").getFile()); + SkipGramNetwork network = SkipGramNetwork.newModel().withWindow(4).fromTextAt(path).withDimension(10).build(); + RealMatrix wv = network.getWeights()[0]; + List<String> vocabulary = network.getVocabulary(); + 
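+        // weights[0] (input-to-hidden) is used as the word embedding table: serialize() and measure()
+        // below read column i of it, bias entry dropped, as the vector for vocabulary term i - 1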
serialize(vocabulary, wv); + measure(vocabulary, wv); + } + + private void measure(List<String> vocabulary, RealMatrix wordVectors) { + System.out.println("measuring similarities"); + Collection<DistanceMeasure> measures = new LinkedList<>(); + measures.add(new EuclideanDistance()); +// measures.add(new DistanceMeasure() { +// @Override +// public double compute(double[] a, double[] b) { +// double dp = 0.0; +// double na = 0.0; +// double nb = 0.0; +// for (int i = 0; i < a.length; i++) { +// dp += a[i] * b[i]; +// na += Math.pow(a[i], 2); +// nb += Math.pow(b[i], 2); +// } +// double cosineSimilarity = dp / (Math.sqrt(na) * Math.sqrt(nb)); +// return 1 / cosineSimilarity; +// } +// +// @Override +// public String toString() { +// return "inverse cosine similarity distance measure"; +// } +// }); +// measures.add((DistanceMeasure) (a, b) -> { +// double da = FastMath.sqrt(MatrixUtils.createRealVector(a).dotProduct(MatrixUtils.createRealVector(a))); +// double db = FastMath.sqrt(MatrixUtils.createRealVector(b).dotProduct(MatrixUtils.createRealVector(b))); +// return Math.abs(db - da); +// }); + for (DistanceMeasure distanceMeasure : measures) { + System.out.println("computing similarity using " + distanceMeasure); + computeSimilarities(vocabulary, wordVectors, distanceMeasure); + } + + } + + private void serialize(List<String> vocabulary, RealMatrix wordVectors) throws IOException { + System.out.println("serializing word vectors"); + BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(new File("target/sg-vectors.csv"))); + for (int i = 1; i < wordVectors.getColumnDimension(); i++) { + double[] a = wordVectors.getColumnVector(i).toArray(); + String csq = Arrays.toString(Arrays.copyOfRange(a, 1, a.length)); + csq = csq.substring(1, csq.length() - 1); + bufferedWriter.append(csq); + bufferedWriter.append(", "); + bufferedWriter.append(vocabulary.get(i - 1)); + bufferedWriter.newLine(); + } + bufferedWriter.flush(); + bufferedWriter.close(); + + // for post processing with dimensionality reduction (PCA, t-SNE, etc.): + // values: awk '{$hiddenSize=""; print $0}' target/sg-vectors.csv + // keys: awk '{print $hiddenSize}' target/sg-vectors.csv + } + + private void computeSimilarities(List<String> vocabulary, RealMatrix wordVectors, DistanceMeasure distanceMeasure) { + for (int i = 1; i < wordVectors.getColumnDimension(); i++) { + double[] subjectVector = wordVectors.getColumn(i); + subjectVector = Arrays.copyOfRange(subjectVector, 1, subjectVector.length); + double maxSimilarity = -Double.MAX_VALUE; + double maxSimilarity1 = -Double.MAX_VALUE; + double maxSimilarity2 = -Double.MAX_VALUE; + int j0 = -1; + int j1 = -1; + int j2 = -1; + for (int j = 1; j < wordVectors.getColumnDimension(); j++) { + if (i != j) { + double[] vector = wordVectors.getColumn(j); + vector = Arrays.copyOfRange(vector, 1, vector.length); + double similarity = 1d / distanceMeasure.compute(subjectVector, vector); + if (similarity > maxSimilarity) { + maxSimilarity2 = maxSimilarity1; + j2 = j1; + + maxSimilarity1 = maxSimilarity; + j1 = j0; + + maxSimilarity = similarity; + j0 = j; + } else if (similarity > maxSimilarity1) { + maxSimilarity2 = maxSimilarity1; + j2 = j1; + + maxSimilarity1 = similarity; + j1 = j; + } else if (similarity > maxSimilarity2) { + maxSimilarity2 = similarity; + j2 = j; + } + } + } + if (i > 0 && j0 > 0 && j1 > 0 && j2 > 0) { + System.out.println(vocabulary.get(i - 1) + " -> " + + vocabulary.get(j0 - 1) + ", " + + vocabulary.get(j1 - 1) + ", " + + vocabulary.get(j2 - 1)); + } else { + 
+        System.err.println("no similarity for '" + vocabulary.get(i - 1) + "' with " + distanceMeasure);
+      }
+    }
+  }
+}

Propchange: labs/yay/trunk/core/src/test/java/org/apache/yay/SkipGramNetworkTest.java
------------------------------------------------------------------------------
    svn:eol-style = native
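As a reading aid for the skip-gram code above: createTrainingSet emits one sample per text fragment, pairing a word's vocabulary index (the input) with the indices of the window - 1 surrounding words (the outputs); HotEncodedSample presumably expands those indices into vocabulary-sized indicator vectors. A standalone sketch of that index-and-one-hot layout (the class, data and variable names are illustrative only, not part of this commit):

    package org.apache.yay;

    import java.util.Arrays;
    import java.util.List;

    /**
     * Illustrative sketch of the index-based sample layout fed into HotEncodedSample.
     */
    public class SkipGramSampleSketch {

      public static void main(String[] args) {
        // sorted vocabulary, as produced by getVocabulary()
        List<String> vocabulary = Arrays.asList("brown", "fox", "jumps", "quick", "the");
        int window = 4; // fragment size, as used by the tests (withWindow(4))

        // fragment "the quick brown fox": one word is the input, the other window - 1 are outputs
        double[] inputs = new double[]{vocabulary.indexOf("the")};
        String[] context = {"quick", "brown", "fox"};
        double[] outputs = new double[window - 1];
        for (int i = 0; i < outputs.length; i++) {
          outputs[i] = vocabulary.indexOf(context[i]);
        }

        // hot encoding turns a vocabulary index into a vocabulary-sized indicator vector
        double[] encodedInput = new double[vocabulary.size()];
        encodedInput[(int) inputs[0]] = 1d;

        System.out.println("input index   : " + Arrays.toString(inputs));        // [4.0]
        System.out.println("output indices: " + Arrays.toString(outputs));       // [3.0, 0.0, 1.0]
        System.out.println("one-hot input : " + Arrays.toString(encodedInput));  // [0.0, 0.0, 0.0, 0.0, 1.0]
      }
    }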