Author: tommaso Date: Fri Jan 15 16:25:44 2016 New Revision: 1724846 URL: http://svn.apache.org/viewvc?rev=1724846&view=rev Log: added Sample API (Sample, HotEncodedSample) and word vectors test (WordVectorsTest)
Added: labs/yay/trunk/core/src/main/java/org/apache/yay/HotEncodedSample.java (with props) labs/yay/trunk/core/src/main/java/org/apache/yay/Sample.java (with props) labs/yay/trunk/core/src/test/java/org/apache/yay/WordVectorsTest.java (with props) Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/ShallowFeedForwardNeuralNetwork.java labs/yay/trunk/core/src/test/java/org/apache/yay/ShallowFeedForwardNeuralNetworkTest.java Added: labs/yay/trunk/core/src/main/java/org/apache/yay/HotEncodedSample.java URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/HotEncodedSample.java?rev=1724846&view=auto ============================================================================== --- labs/yay/trunk/core/src/main/java/org/apache/yay/HotEncodedSample.java (added) +++ labs/yay/trunk/core/src/main/java/org/apache/yay/HotEncodedSample.java Fri Jan 15 16:25:44 2016 @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.yay; + +import java.util.Arrays; + +/** + * an hot-encoded {@link Sample} + */ +public class HotEncodedSample extends Sample { + + private final int vocabularySize; + + public HotEncodedSample(double[] inputs, double[] outputs, int vocabularySize) { + super(inputs, outputs); + this.vocabularySize = vocabularySize; + } + + @Override + public double[] getInputs() { + double[] inputs = new double[1 + this.inputs.length * vocabularySize]; + inputs[0] = 1d; + int i = 1; + for (double d : this.inputs) { + double[] currentInput = hotEncode((int) d); + System.arraycopy(currentInput, 0, inputs, i, currentInput.length); + i += vocabularySize; + } + return inputs; + } + + @Override + public double[] getOutputs() { + double[] outputs = new double[this.outputs.length * vocabularySize]; + int i = 0; + for (double d : this.outputs) { + double[] currentOutput = hotEncode((int) d); + System.arraycopy(currentOutput, 0, outputs, i, currentOutput.length); + i += vocabularySize; + } + return outputs; + } + + private double[] hotEncode(int index) { + double[] vector = new double[vocabularySize]; + Arrays.fill(vector, 0d); + vector[index] = 1d; + return vector; + } +} Propchange: labs/yay/trunk/core/src/main/java/org/apache/yay/HotEncodedSample.java ------------------------------------------------------------------------------ svn:eol-style = native Added: labs/yay/trunk/core/src/main/java/org/apache/yay/Sample.java URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/Sample.java?rev=1724846&view=auto ============================================================================== --- labs/yay/trunk/core/src/main/java/org/apache/yay/Sample.java (added) +++ labs/yay/trunk/core/src/main/java/org/apache/yay/Sample.java Fri Jan 15 16:25:44 2016 @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.yay; + +/** + * a training example + */ +public class Sample { + + protected final double[] inputs; + protected final double[] outputs; + + public Sample(double[] inputs, double[] outputs) { + this.inputs = inputs; + + this.outputs = outputs; + } + + /** + * get the inputs as a double vector + * + * @return a double array + */ + public double[] getInputs() { + double[] result = new double[inputs.length + 1]; + result[0] = 1d; + System.arraycopy(inputs, 0, result, 1, inputs.length); + return result; + } + + /** + * get the outputs as a double vector + * + * @return a double array + */ + public double[] getOutputs() { + return outputs; + } + +} Propchange: labs/yay/trunk/core/src/main/java/org/apache/yay/Sample.java ------------------------------------------------------------------------------ svn:eol-style = native Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/ShallowFeedForwardNeuralNetwork.java URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/ShallowFeedForwardNeuralNetwork.java?rev=1724846&r1=1724845&r2=1724846&view=diff ============================================================================== --- labs/yay/trunk/core/src/main/java/org/apache/yay/ShallowFeedForwardNeuralNetwork.java (original) +++ 
labs/yay/trunk/core/src/main/java/org/apache/yay/ShallowFeedForwardNeuralNetwork.java Fri Jan 15 16:25:44 2016 @@ -61,6 +61,10 @@ public class ShallowFeedForwardNeuralNet this.weights = weights; } + public RealMatrix[] getWeights() { + return weights; + } + private RealMatrix[] createRandomWeights() { Random r = new Random(); int[] layers = new int[configuration.layers.length]; @@ -112,7 +116,8 @@ public class ShallowFeedForwardNeuralNet * @return the final cost with the updated weights * @throws Exception if SGD fails to converge or any numerical error happens */ - public double learnWeights(double[]... samples) throws Exception { + public double learnWeights(Sample... samples) throws Exception { + double newCost; int iterations = 0; @@ -126,11 +131,10 @@ public class ShallowFeedForwardNeuralNet } } // current training example - double[] sample = samples[iterations % samples.length]; + Sample sample = samples[iterations % samples.length]; - int outputLayerSize = configuration.layers[configuration.layers.length - 1]; - double[] expectedOutput = getSampleOutput(sample, outputLayerSize); - double[] input = getSampleInput(sample, outputLayerSize); + double[] expectedOutput = sample.getOutputs(); + double[] input = sample.getInputs(); double[] predictedOutput = predictOutput(input); // TODO : use debugOutput to avoid performing it again when calculating derivatives @@ -150,7 +154,7 @@ public class ShallowFeedForwardNeuralNet cost = newCost; // calculate the derivatives to update the parameters - RealMatrix[] derivatives = calculateDerivatives(input, expectedOutput, sample.length); + RealMatrix[] derivatives = calculateDerivatives(input, expectedOutput, samples.length); // update the weights weights = getUpdatedWeights(derivatives); @@ -160,20 +164,6 @@ public class ShallowFeedForwardNeuralNet return newCost; } - // --- sample parsing --- - - private double[] getSampleInput(double[] sample, int outputLayerSize) { - double[] input = Arrays.copyOfRange(sample, 
outputLayerSize, sample.length); - double[] result = new double[input.length + 1]; - result[0] = 1d; - System.arraycopy(input, 0, result, 1, input.length); - return result; - } - - private double[] getSampleOutput(double[] sample, int outputLayerSize) { - return Arrays.copyOfRange(sample, 0, outputLayerSize); - } - // --- backpropagation --- private RealMatrix[] calculateDerivatives(double[] input, double[] output, int size) throws Exception { Modified: labs/yay/trunk/core/src/test/java/org/apache/yay/ShallowFeedForwardNeuralNetworkTest.java URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/ShallowFeedForwardNeuralNetworkTest.java?rev=1724846&r1=1724845&r2=1724846&view=diff ============================================================================== --- labs/yay/trunk/core/src/test/java/org/apache/yay/ShallowFeedForwardNeuralNetworkTest.java (original) +++ labs/yay/trunk/core/src/test/java/org/apache/yay/ShallowFeedForwardNeuralNetworkTest.java Fri Jan 15 16:25:44 2016 @@ -43,19 +43,11 @@ public class ShallowFeedForwardNeuralNet ShallowFeedForwardNeuralNetwork neuralNetwork = new ShallowFeedForwardNeuralNetwork(configuration); assertNotNull(neuralNetwork); - double[][] samples = new double[3][4]; - samples[0][0] = 0.1d; - samples[0][1] = 0.2d; - samples[0][2] = 0.3d; - samples[0][3] = 0.4d; - samples[1][0] = 0.5d; - samples[1][1] = 0.6d; - samples[1][2] = 0.7d; - samples[1][3] = 0.8d; - samples[2][0] = 0.9d; - samples[2][1] = 0.1d; - samples[2][2] = 0.2d; - samples[2][3] = 0.3d; + Sample[] samples = new Sample[3]; + samples[0] = new Sample(new double[]{0.2, 0.3, 0.4}, new double[]{0.1}); + samples[1] = new Sample(new double[]{0.6, 0.7, 0.8}, new double[]{0.5}); + samples[2] = new Sample(new double[]{0.1, 0.2, 0.3}, new double[]{0.9}); + double cost = neuralNetwork.learnWeights(samples); assertTrue(cost > 0 && cost < 10); Added: labs/yay/trunk/core/src/test/java/org/apache/yay/WordVectorsTest.java URL: 
http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/WordVectorsTest.java?rev=1724846&view=auto ============================================================================== --- labs/yay/trunk/core/src/test/java/org/apache/yay/WordVectorsTest.java (added) +++ labs/yay/trunk/core/src/test/java/org/apache/yay/WordVectorsTest.java Fri Jan 15 16:25:44 2016 @@ -0,0 +1,431 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.yay; + +import com.google.common.base.Splitter; +import org.apache.commons.math3.linear.MatrixUtils; +import org.apache.commons.math3.linear.RealMatrix; +import org.apache.commons.math3.ml.distance.DistanceMeasure; +import org.apache.commons.math3.ml.distance.EuclideanDistance; +import org.junit.Test; + +import java.io.BufferedReader; +import java.io.BufferedWriter; +import java.io.File; +import java.io.FileWriter; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.nio.ByteBuffer; +import java.nio.CharBuffer; +import java.nio.channels.SeekableByteChannel; +import java.nio.charset.Charset; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashSet; +import java.util.LinkedList; +import java.util.List; +import java.util.Queue; +import java.util.Random; +import java.util.Set; +import java.util.concurrent.ConcurrentLinkedDeque; +import java.util.regex.Pattern; + +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; + +/** + * Integration test for using Yay to implement word vectors algorithms. 
+ */ +public class WordVectorsTest { + + private static final boolean measure = false; + + private static final boolean serialize = true; + + @Test + public void testSGM() throws Exception { + + Path path = Paths.get(getClass().getResource("/word2vec/abstracts.txt").getFile()); + + System.out.println("reading fragments"); + int window = 3; + Queue<List<byte[]>> fragments = getFragments(path, window); + assertFalse(fragments.isEmpty()); + System.out.println("generating vocabulary"); + List<String> vocabulary = getVocabulary(path); + assertFalse(vocabulary.isEmpty()); + + System.out.println("creating training set"); + Collection<HotEncodedSample> trainingSet = createTrainingSet(vocabulary, fragments, window); + fragments.clear(); + HotEncodedSample next = trainingSet.iterator().next(); + + int inputSize = next.getInputs().length - 1; + int outputSize = next.getOutputs().length; + + int hiddenSize = 10; + System.out.println("initializing neural network"); + + ActivationFunction[] activationFunctions = new ActivationFunction[2]; + activationFunctions[0] = new IdentityActivationFunction(); + activationFunctions[1] = new SoftmaxActivationFunction(); + + ShallowFeedForwardNeuralNetwork.Configuration configuration = new ShallowFeedForwardNeuralNetwork.Configuration(); + configuration.alpha = 0.00001d; + configuration.layers = new int[]{inputSize, hiddenSize, outputSize}; + configuration.maxIterations = trainingSet.size(); + configuration.threshold = 0.00000004d; + configuration.activationFunctions = activationFunctions; + + ShallowFeedForwardNeuralNetwork neuralNetwork = new ShallowFeedForwardNeuralNetwork(configuration); + System.out.println("learning..."); + long start = System.currentTimeMillis(); + neuralNetwork.learnWeights(trainingSet.toArray(new Sample[trainingSet.size()])); + RealMatrix[] learnedWeights = neuralNetwork.getWeights(); + System.out.println("learning finished in " + (System.currentTimeMillis() - start) / 60000 + " minutes"); + + RealMatrix wordVectors 
= learnedWeights[0]; + + assertNotNull(wordVectors); + + if (serialize) { + serialize(vocabulary, wordVectors); + } + + if (measure) { + measure(vocabulary, wordVectors); + } + } + + private void measure(List<String> vocabulary, RealMatrix wordVectors) { + System.out.println("measuring similarities"); + Collection<DistanceMeasure> measures = new LinkedList<>(); + measures.add(new EuclideanDistance()); +// measures.add(new DistanceMeasure() { +// @Override +// public double compute(double[] a, double[] b) { +// double dp = 0.0; +// double na = 0.0; +// double nb = 0.0; +// for (int i = 0; i < a.length; i++) { +// dp += a[i] * b[i]; +// na += Math.pow(a[i], 2); +// nb += Math.pow(b[i], 2); +// } +// double cosineSimilarity = dp / (Math.sqrt(na) * Math.sqrt(nb)); +// return 1 / cosineSimilarity; +// } +// +// @Override +// public String toString() { +// return "inverse cosine similarity distance measure"; +// } +// }); +// measures.add((DistanceMeasure) (a, b) -> { +// double da = FastMath.sqrt(MatrixUtils.createRealVector(a).dotProduct(MatrixUtils.createRealVector(a))); +// double db = FastMath.sqrt(MatrixUtils.createRealVector(b).dotProduct(MatrixUtils.createRealVector(b))); +// return Math.abs(db - da); +// }); + for (DistanceMeasure distanceMeasure : measures) { + System.out.println("computing similarity using " + distanceMeasure); + computeSimilarities(vocabulary, wordVectors, distanceMeasure); + } + } + + private void serialize(List<String> vocabulary, RealMatrix wordVectors) throws IOException { + System.out.println("serializing word vectors"); + BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(new File("target/sg-vectors.csv"))); + for (int i = 1; i < wordVectors.getColumnDimension(); i++) { + double[] a = wordVectors.getColumnVector(i).toArray(); + String csq = Arrays.toString(Arrays.copyOfRange(a, 1, a.length)); + csq = csq.substring(1, csq.length() - 1); + bufferedWriter.append(csq); + bufferedWriter.append(", "); + 
bufferedWriter.append(vocabulary.get(i - 1)); + bufferedWriter.newLine(); + } + bufferedWriter.flush(); + bufferedWriter.close(); + + // for post processing with dimensionality reduction (PCA, t-SNE, etc.): + // values: awk '{$hiddenSize=""; print $0}' target/sg-vectors.csv + // keys: awk '{print $hiddenSize}' target/sg-vectors.csv + } + + private void computeSimilarities(List<String> vocabulary, RealMatrix wordVectors, DistanceMeasure distanceMeasure) { + for (int i = 1; i < wordVectors.getColumnDimension(); i++) { + double[] subjectVector = wordVectors.getColumn(i); + subjectVector = Arrays.copyOfRange(subjectVector, 1, subjectVector.length); + double maxSimilarity = -Double.MAX_VALUE; + double maxSimilarity1 = -Double.MAX_VALUE; + double maxSimilarity2 = -Double.MAX_VALUE; + int j0 = -1; + int j1 = -1; + int j2 = -1; + for (int j = 1; j < wordVectors.getColumnDimension(); j++) { + if (i != j) { + double[] vector = wordVectors.getColumn(j); + vector = Arrays.copyOfRange(vector, 1, vector.length); + double similarity = 1d / distanceMeasure.compute(subjectVector, vector); + if (similarity > maxSimilarity) { + maxSimilarity2 = maxSimilarity1; + j2 = j1; + + maxSimilarity1 = maxSimilarity; + j1 = j0; + + maxSimilarity = similarity; + j0 = j; + } else if (similarity > maxSimilarity1) { + maxSimilarity2 = maxSimilarity1; + j2 = j1; + + maxSimilarity1 = similarity; + j1 = j; + } else if (similarity > maxSimilarity2) { + maxSimilarity2 = similarity; + j2 = j; + } + } + } + if (i > 0 && j0 > 0 && j1 > 0 && j2 > 0) { + System.out.println(vocabulary.get(i - 1) + " -> " + + vocabulary.get(j0 - 1) + ", " + + vocabulary.get(j1 - 1) + ", " + + vocabulary.get(j2 - 1)); + } else { + System.err.println("no similarity for '" + vocabulary.get(i) + "' with " + distanceMeasure); + } + } + } + + private Collection<HotEncodedSample> createTrainingSet(final List<String> vocabulary, Queue<List<byte[]>> fragments, int window) throws IOException { + long start = System.currentTimeMillis(); 
+ Collection<HotEncodedSample> samples = new LinkedList<>(); + List<byte[]> fragment; + while ((fragment = fragments.poll()) != null) { + byte[] inputWord = null; + List<byte[]> outputWords = new ArrayList<>(fragment.size() - 1); + for (int i = 0; i < fragment.size(); i++) { + for (int j = 0; j < fragment.size(); j++) { + byte[] token = fragment.get(i); + if (i == j) { + inputWord = token; + } else { + outputWords.add(token); + } + } + } + final byte[] finalInputWord = inputWord; + + double[] doubles = new double[window - 1]; + for (int i = 0; i < doubles.length; i++) { + doubles[i] = (double) vocabulary.indexOf(new String(outputWords.get(i))); + } + + double[] inputs = new double[1]; + inputs[0] = (double) vocabulary.indexOf(new String(finalInputWord)); + + samples.add(new HotEncodedSample(inputs, doubles, vocabulary.size())); + + } + + long end = System.currentTimeMillis(); + System.out.println("training set created in " + (end - start) / 60000 + " minutes"); + + return samples; + } + + public static double[] hotEncode(int index, int size) { + double[] vector = new double[size]; + Arrays.fill(vector, 0d); + vector[index] = 1d; + return vector; + } + + + private List<String> getVocabulary(Path path) throws IOException { + Set<String> vocabulary = new HashSet<>(); + ByteBuffer buf = ByteBuffer.allocate(100); + try (SeekableByteChannel sbc = Files.newByteChannel(path)) { + + String encoding = System.getProperty("file.encoding"); + StringBuilder previous = new StringBuilder(); + Splitter splitter = Splitter.on(Pattern.compile("[\\\n\\s]")).omitEmptyStrings().trimResults(); + while (sbc.read(buf) > 0) { + buf.rewind(); + CharBuffer charBuffer = Charset.forName(encoding).decode(buf); + String string = cleanString(charBuffer); + List<String> split = splitter.splitToList(string); + int splitSize = split.size(); + if (splitSize > 1) { + String term = previous.append(split.get(0)).toString(); + vocabulary.add(term.intern()); + for (int i = 1; i < splitSize - 1; i++) { + 
String term2 = split.get(i); + vocabulary.add(term2.intern()); + } + previous = new StringBuilder().append(split.get(splitSize - 1)); + } else if (split.size() == 1) { + previous.append(string); + } + buf.flip(); + } + } catch (IOException x) { + System.err.println("caught exception: " + x); + } finally { + buf.clear(); + } + List<String> list = Arrays.asList(vocabulary.toArray(new String[vocabulary.size()])); + Collections.sort(list); +// for (String iw : vocabulary) { +// System.out.println(iw +"->"+Arrays.toString(ConversionUtils.hotEncode(iw.getBytes(), list))); +// } + return list; + } + + private String cleanString(CharBuffer charBuffer) { + String s = charBuffer.toString(); + return s.toLowerCase().replaceAll("\\.", " ").replaceAll("\\;", " ").replaceAll("\\,", " ").replaceAll("\\:", " ").replaceAll("\\-\\s", "").replaceAll("\\\"", ""); + } + + private List<String> getVocabulary(Collection<byte[]> sentences) { + long start = System.currentTimeMillis(); + List<String> vocabulary = new LinkedList<>(); + for (byte[] sentence : sentences) { + for (String token : new String(sentence).split(" ")) { + if (!vocabulary.contains(token)) { + vocabulary.add(token); + } + } + } + System.out.println("sorting vocabulary"); + Collections.sort(vocabulary); + long end = System.currentTimeMillis(); + System.out.println("vocabulary generated in " + (end - start) / 60000 + " minutes"); + return vocabulary; + } + + private Queue<List<byte[]>> getFragments(Path path, int w) throws IOException { + long start = System.currentTimeMillis(); + Queue<List<byte[]>> fragments = new ConcurrentLinkedDeque<>(); + + ByteBuffer buf = ByteBuffer.allocate(100); + try (SeekableByteChannel sbc = Files.newByteChannel(path)) { + + String encoding = System.getProperty("file.encoding"); + StringBuilder previous = new StringBuilder(); + Splitter splitter = Splitter.on(Pattern.compile("[\\n\\s]")).omitEmptyStrings().trimResults(); + while (sbc.read(buf) > 0) { + buf.rewind(); + CharBuffer charBuffer = 
Charset.forName(encoding).decode(buf); + String string = cleanString(charBuffer); + List<String> split = splitter.splitToList(string); + int splitSize = split.size(); + if (splitSize > w) { + for (int j = 0; j < splitSize - w; j++) { + List<byte[]> fragment = new ArrayList<>(w); + fragment.add(previous.append(split.get(j)).toString().getBytes()); + for (int i = 1; i < w; i++) { + fragment.add(split.get(i + j).getBytes()); + } + // TODO : this has to be used to re-use the tokens that have not been consumed in next iteration + fragments.add(fragment); + previous = new StringBuilder(); + } + previous = new StringBuilder().append(split.get(splitSize - 1)); + } else if (split.size() == w) { + previous.append(string); + } + buf.flip(); + } + } catch (IOException x) { + System.err.println("caught exception: " + x); + } finally { + buf.clear(); + } + long end = System.currentTimeMillis(); + System.out.println("fragments read in " + (end - start) / 60000 + " minutes (" + fragments.size() + ")"); + return fragments; + } + + private Collection<String> getSentences() throws IOException { + Collection<String> sentences = new LinkedList<>(); + + InputStream resourceAsStream = getClass().getResourceAsStream("/word2vec/test.txt"); + BufferedReader bufferedReader = new BufferedReader(new InputStreamReader(resourceAsStream)); + String line; + while ((line = bufferedReader.readLine()) != null) { + String cleanLine = line.toLowerCase().replaceAll("\\.", " ").replaceAll(";", " ").replaceAll(",", " ").replaceAll(":", " ").replaceAll("-", ""); + sentences.add(cleanLine); + } + return sentences; + } + + private RealMatrix[] createRandomWeights(int inputSize, int hiddenSize, int outputSize) { + Random r = new Random(); + int weightsCount = 2; + + RealMatrix[] initialWeights = new RealMatrix[weightsCount]; + for (int i = 0; i < weightsCount; i++) { + int rows = hiddenSize; + int cols; + if (i == 0) { + cols = inputSize; + } else { + cols = initialWeights[i - 1].getRowDimension(); + if (i == 
weightsCount - 1) { + rows = outputSize; + } + } + double[][] d = new double[rows][cols]; + for (int c = 0; c < cols; c++) { + if (i == weightsCount - 1) { + if (c == 0) { + d[0][c] = 1d; + } else { + d[0][c] = r.nextInt(100) / 101d; + } + } else { + d[0][c] = 0; + } + } + + for (int k = 1; k < rows; k++) { + for (int j = 0; j < cols; j++) { + double val; + if (j == 0) { + val = 1d; + } else { + val = r.nextInt(100) / 101d; + } + d[k][j] = val; + } + } + initialWeights[i] = MatrixUtils.createRealMatrix(d); + } + return initialWeights; + } +} \ No newline at end of file Propchange: labs/yay/trunk/core/src/test/java/org/apache/yay/WordVectorsTest.java ------------------------------------------------------------------------------ svn:eol-style = native --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@labs.apache.org For additional commands, e-mail: commits-h...@labs.apache.org