Author: tommaso
Date: Fri Jan 15 16:25:44 2016
New Revision: 1724846

URL: http://svn.apache.org/viewvc?rev=1724846&view=rev
Log:
added sample API, wv test

Added:
    labs/yay/trunk/core/src/main/java/org/apache/yay/HotEncodedSample.java   
(with props)
    labs/yay/trunk/core/src/main/java/org/apache/yay/Sample.java   (with props)
    labs/yay/trunk/core/src/test/java/org/apache/yay/WordVectorsTest.java   
(with props)
Modified:
    
labs/yay/trunk/core/src/main/java/org/apache/yay/ShallowFeedForwardNeuralNetwork.java
    
labs/yay/trunk/core/src/test/java/org/apache/yay/ShallowFeedForwardNeuralNetworkTest.java

Added: labs/yay/trunk/core/src/main/java/org/apache/yay/HotEncodedSample.java
URL: 
http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/HotEncodedSample.java?rev=1724846&view=auto
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/HotEncodedSample.java 
(added)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/HotEncodedSample.java Fri 
Jan 15 16:25:44 2016
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *  http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.yay;
+
+import java.util.Arrays;
+
+/**
+ * A one-hot encoded {@link Sample}: every value of the wrapped input and output arrays is
+ * truncated to an {@code int}, taken as an index into a vocabulary of {@code vocabularySize}
+ * entries and expanded into {@code vocabularySize} slots holding {@code 1} at that index and
+ * {@code 0} everywhere else.
+ */
+public class HotEncodedSample extends Sample {
+
+  private final int vocabularySize;
+
+  /**
+   * @param inputs         indices of the input terms
+   * @param outputs        indices of the output terms
+   * @param vocabularySize number of slots each index is expanded into
+   */
+  public HotEncodedSample(double[] inputs, double[] outputs, int vocabularySize) {
+    super(inputs, outputs);
+    this.vocabularySize = vocabularySize;
+  }
+
+  /**
+   * get the one-hot encoded inputs, preceded by a constant {@code 1} at index 0
+   *
+   * @return a new array of length {@code 1 + inputs.length * vocabularySize}
+   * @throws ArrayIndexOutOfBoundsException if an input index is not within the vocabulary
+   */
+  @Override
+  public double[] getInputs() {
+    double[] encoded = new double[1 + this.inputs.length * vocabularySize];
+    encoded[0] = 1d; // the same leading 1 that Sample#getInputs() prepends
+    encodeInto(encoded, 1, this.inputs);
+    return encoded;
+  }
+
+  /**
+   * get the one-hot encoded outputs
+   *
+   * @return a new array of length {@code outputs.length * vocabularySize}
+   * @throws ArrayIndexOutOfBoundsException if an output index is not within the vocabulary
+   */
+  @Override
+  public double[] getOutputs() {
+    double[] encoded = new double[this.outputs.length * vocabularySize];
+    encodeInto(encoded, 0, this.outputs);
+    return encoded;
+  }
+
+  /**
+   * Writes the one-hot expansion of each index straight into {@code target}, starting at
+   * {@code offset}; {@code target} is expected to be freshly allocated (all zeros).
+   */
+  private void encodeInto(double[] target, int offset, double[] indices) {
+    for (double value : indices) {
+      int index = (int) value;
+      // an unchecked index (e.g. the -1 of a failed List#indexOf) would either throw without
+      // context or silently set a slot that belongs to the neighbouring term
+      if (index < 0 || index >= vocabularySize) {
+        throw new ArrayIndexOutOfBoundsException(
+            "term index " + index + " is outside the vocabulary [0, " + vocabularySize + ")");
+      }
+      target[offset + index] = 1d;
+      offset += vocabularySize;
+    }
+  }
+}

Propchange: 
labs/yay/trunk/core/src/main/java/org/apache/yay/HotEncodedSample.java
------------------------------------------------------------------------------
    svn:eol-style = native

Added: labs/yay/trunk/core/src/main/java/org/apache/yay/Sample.java
URL: 
http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/Sample.java?rev=1724846&view=auto
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/Sample.java (added)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/Sample.java Fri Jan 15 
16:25:44 2016
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *  http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.yay;
+
+/**
+ * A single training example: a vector of input values paired with the expected output values.
+ */
+public class Sample {
+
+  protected final double[] inputs;
+  protected final double[] outputs;
+
+  public Sample(double[] inputs, double[] outputs) {
+    this.inputs = inputs;
+    this.outputs = outputs;
+  }
+
+  /**
+   * get the inputs as a double vector, preceded by a constant {@code 1} at index 0
+   *
+   * @return a new double array of length {@code inputs.length + 1}
+   */
+  public double[] getInputs() {
+    final int size = inputs.length;
+    double[] withLeadingOne = new double[size + 1];
+    withLeadingOne[0] = 1d;
+    for (int k = 0; k < size; k++) {
+      withLeadingOne[k + 1] = inputs[k];
+    }
+    return withLeadingOne;
+  }
+
+  /**
+   * get the outputs as a double vector
+   *
+   * @return the very array that was handed to the constructor (not a copy)
+   */
+  public double[] getOutputs() {
+    return outputs;
+  }
+
+}

Propchange: labs/yay/trunk/core/src/main/java/org/apache/yay/Sample.java
------------------------------------------------------------------------------
    svn:eol-style = native

Modified: 
labs/yay/trunk/core/src/main/java/org/apache/yay/ShallowFeedForwardNeuralNetwork.java
URL: 
http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/ShallowFeedForwardNeuralNetwork.java?rev=1724846&r1=1724845&r2=1724846&view=diff
==============================================================================
--- 
labs/yay/trunk/core/src/main/java/org/apache/yay/ShallowFeedForwardNeuralNetwork.java
 (original)
+++ 
labs/yay/trunk/core/src/main/java/org/apache/yay/ShallowFeedForwardNeuralNetwork.java
 Fri Jan 15 16:25:44 2016
@@ -61,6 +61,10 @@ public class ShallowFeedForwardNeuralNet
     this.weights = weights;
   }
 
+  public RealMatrix[] getWeights() {
+    return weights;
+  }
+
   private RealMatrix[] createRandomWeights() {
     Random r = new Random();
     int[] layers = new int[configuration.layers.length];
@@ -112,7 +116,8 @@ public class ShallowFeedForwardNeuralNet
    * @return the final cost with the updated weights
    * @throws Exception if SGD fails to converge or any numerical error happens
    */
-  public double learnWeights(double[]... samples) throws Exception {
+  public double learnWeights(Sample... samples) throws Exception {
+
     double newCost;
     int iterations = 0;
 
@@ -126,11 +131,10 @@ public class ShallowFeedForwardNeuralNet
         }
       }
       // current training example
-      double[] sample = samples[iterations % samples.length];
+      Sample sample = samples[iterations % samples.length];
 
-      int outputLayerSize = configuration.layers[configuration.layers.length - 
1];
-      double[] expectedOutput = getSampleOutput(sample, outputLayerSize);
-      double[] input = getSampleInput(sample, outputLayerSize);
+      double[] expectedOutput = sample.getOutputs();
+      double[] input = sample.getInputs();
 
       double[] predictedOutput = predictOutput(input); // TODO : use 
debugOutput to avoid performing it again when calculating derivatives
 
@@ -150,7 +154,7 @@ public class ShallowFeedForwardNeuralNet
       cost = newCost;
 
       // calculate the derivatives to update the parameters
-      RealMatrix[] derivatives = calculateDerivatives(input, expectedOutput, 
sample.length);
+      RealMatrix[] derivatives = calculateDerivatives(input, expectedOutput, 
samples.length);
 
       // update the weights
       weights = getUpdatedWeights(derivatives);
@@ -160,20 +164,6 @@ public class ShallowFeedForwardNeuralNet
     return newCost;
   }
 
-  // --- sample parsing ---
-
-  private double[] getSampleInput(double[] sample, int outputLayerSize) {
-    double[] input = Arrays.copyOfRange(sample, outputLayerSize, 
sample.length);
-    double[] result = new double[input.length + 1];
-    result[0] = 1d;
-    System.arraycopy(input, 0, result, 1, input.length);
-    return result;
-  }
-
-  private double[] getSampleOutput(double[] sample, int outputLayerSize) {
-    return Arrays.copyOfRange(sample, 0, outputLayerSize);
-  }
-
   // --- backpropagation ---
 
   private RealMatrix[] calculateDerivatives(double[] input, double[] output, 
int size) throws Exception {

Modified: 
labs/yay/trunk/core/src/test/java/org/apache/yay/ShallowFeedForwardNeuralNetworkTest.java
URL: 
http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/ShallowFeedForwardNeuralNetworkTest.java?rev=1724846&r1=1724845&r2=1724846&view=diff
==============================================================================
--- 
labs/yay/trunk/core/src/test/java/org/apache/yay/ShallowFeedForwardNeuralNetworkTest.java
 (original)
+++ 
labs/yay/trunk/core/src/test/java/org/apache/yay/ShallowFeedForwardNeuralNetworkTest.java
 Fri Jan 15 16:25:44 2016
@@ -43,19 +43,11 @@ public class ShallowFeedForwardNeuralNet
     ShallowFeedForwardNeuralNetwork neuralNetwork = new 
ShallowFeedForwardNeuralNetwork(configuration);
 
     assertNotNull(neuralNetwork);
-    double[][] samples = new double[3][4];
-    samples[0][0] = 0.1d;
-    samples[0][1] = 0.2d;
-    samples[0][2] = 0.3d;
-    samples[0][3] = 0.4d;
-    samples[1][0] = 0.5d;
-    samples[1][1] = 0.6d;
-    samples[1][2] = 0.7d;
-    samples[1][3] = 0.8d;
-    samples[2][0] = 0.9d;
-    samples[2][1] = 0.1d;
-    samples[2][2] = 0.2d;
-    samples[2][3] = 0.3d;
+    Sample[] samples = new Sample[3];
+    samples[0] = new Sample(new double[]{0.2, 0.3, 0.4}, new double[]{0.1});
+    samples[1] = new Sample(new double[]{0.6, 0.7, 0.8}, new double[]{0.5});
+    samples[2] = new Sample(new double[]{0.1, 0.2, 0.3}, new double[]{0.9});
+
     double cost = neuralNetwork.learnWeights(samples);
     assertTrue(cost > 0 && cost < 10);
 

Added: labs/yay/trunk/core/src/test/java/org/apache/yay/WordVectorsTest.java
URL: 
http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/WordVectorsTest.java?rev=1724846&view=auto
==============================================================================
--- labs/yay/trunk/core/src/test/java/org/apache/yay/WordVectorsTest.java 
(added)
+++ labs/yay/trunk/core/src/test/java/org/apache/yay/WordVectorsTest.java Fri 
Jan 15 16:25:44 2016
@@ -0,0 +1,431 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *  http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.yay;
+
+import com.google.common.base.Splitter;
+import org.apache.commons.math3.linear.MatrixUtils;
+import org.apache.commons.math3.linear.RealMatrix;
+import org.apache.commons.math3.ml.distance.DistanceMeasure;
+import org.apache.commons.math3.ml.distance.EuclideanDistance;
+import org.junit.Test;
+
+import java.io.BufferedReader;
+import java.io.BufferedWriter;
+import java.io.File;
+import java.io.FileWriter;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.InputStreamReader;
+import java.nio.ByteBuffer;
+import java.nio.CharBuffer;
+import java.nio.channels.SeekableByteChannel;
+import java.nio.charset.Charset;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Queue;
+import java.util.Random;
+import java.util.Set;
+import java.util.concurrent.ConcurrentLinkedDeque;
+import java.util.regex.Pattern;
+
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNotNull;
+
+/**
+ * Integration test for using Yay to implement word vectors algorithms.
+ */
+public class WordVectorsTest {
+
+  private static final boolean measure = false;
+
+  private static final boolean serialize = true;
+
+  /**
+   * Trains a {@link ShallowFeedForwardNeuralNetwork} on one-hot encoded word windows read from
+   * the {@code /word2vec/abstracts.txt} classpath resource, takes the first learned weight
+   * matrix as the word vectors and, depending on the {@code serialize} / {@code measure}
+   * switches of this class, writes them to disk and/or prints the most similar words.
+   */
+  @Test
+  public void testSGM() throws Exception {
+
+    // fail with a readable assertion instead of an anonymous NPE when the corpus is missing
+    assertNotNull(getClass().getResource("/word2vec/abstracts.txt"));
+    // resolve through toURI(): URL#getFile() keeps percent-escapes (e.g. %20 for a space in the
+    // checkout directory), which do not name an existing file
+    Path path = Paths.get(getClass().getResource("/word2vec/abstracts.txt").toURI());
+
+    System.out.println("reading fragments");
+    int window = 3;
+    Queue<List<byte[]>> fragments = getFragments(path, window);
+    assertFalse(fragments.isEmpty());
+    System.out.println("generating vocabulary");
+    List<String> vocabulary = getVocabulary(path);
+    assertFalse(vocabulary.isEmpty());
+
+    System.out.println("creating training set");
+    Collection<HotEncodedSample> trainingSet = createTrainingSet(vocabulary, fragments, window);
+    fragments.clear();
+    HotEncodedSample next = trainingSet.iterator().next();
+
+    // getInputs() carries one extra leading element that is not counted as an input unit here
+    int inputSize = next.getInputs().length - 1;
+    int outputSize = next.getOutputs().length;
+
+    int hiddenSize = 10;
+    System.out.println("initializing neural network");
+
+    ActivationFunction[] activationFunctions = new ActivationFunction[2];
+    activationFunctions[0] = new IdentityActivationFunction();
+    activationFunctions[1] = new SoftmaxActivationFunction();
+
+    ShallowFeedForwardNeuralNetwork.Configuration configuration = new ShallowFeedForwardNeuralNetwork.Configuration();
+    configuration.alpha = 0.00001d;
+    configuration.layers = new int[]{inputSize, hiddenSize, outputSize};
+    configuration.maxIterations = trainingSet.size();
+    configuration.threshold = 0.00000004d;
+    configuration.activationFunctions = activationFunctions;
+
+    ShallowFeedForwardNeuralNetwork neuralNetwork = new ShallowFeedForwardNeuralNetwork(configuration);
+    System.out.println("learning...");
+    long start = System.currentTimeMillis();
+    neuralNetwork.learnWeights(trainingSet.toArray(new Sample[trainingSet.size()]));
+    RealMatrix[] learnedWeights = neuralNetwork.getWeights();
+    System.out.println("learning finished in " + (System.currentTimeMillis() - start) / 60000 + " minutes");
+
+    RealMatrix wordVectors = learnedWeights[0];
+
+    assertNotNull(wordVectors);
+
+    if (serialize) {
+      serialize(vocabulary, wordVectors);
+    }
+
+    if (measure) {
+      measure(vocabulary, wordVectors);
+    }
+  }
+
+  /**
+   * Runs {@code computeSimilarities} once per configured {@link DistanceMeasure}, announcing
+   * each measure on standard output first.
+   *
+   * Only the Euclidean distance is enabled. Two further measures were sketched here but never
+   * switched on: the inverse of the cosine similarity of the two vectors, and the absolute
+   * difference between their Euclidean norms. Add them to {@code measures} to compare results.
+   *
+   * @param vocabulary  the words, passed through to {@code computeSimilarities}
+   * @param wordVectors the learned weights, passed through to {@code computeSimilarities}
+   */
+  private void measure(List<String> vocabulary, RealMatrix wordVectors) {
+    System.out.println("measuring similarities");
+    Collection<DistanceMeasure> measures = new ArrayList<>();
+    measures.add(new EuclideanDistance());
+    for (DistanceMeasure distanceMeasure : measures) {
+      System.out.println("computing similarity using " + distanceMeasure);
+      computeSimilarities(vocabulary, wordVectors, distanceMeasure);
+    }
+  }
+
+  /**
+   * Writes one line per word to {@code target/sg-vectors.csv}: the comma separated values of
+   * the word's column of {@code wordVectors} (without the column's first element) followed by
+   * the word itself. Column 0 is skipped, so column {@code i} is paired with
+   * {@code vocabulary.get(i - 1)}.
+   *
+   * @param vocabulary  the words, in the order of the columns {@code 1..n} of the matrix
+   * @param wordVectors the learned weights, one word per column
+   * @throws IOException if the file cannot be created or written
+   */
+  private void serialize(List<String> vocabulary, RealMatrix wordVectors) throws IOException {
+    System.out.println("serializing word vectors");
+    // try-with-resources: the writer is flushed and closed even when a write fails half-way
+    try (BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(new File("target/sg-vectors.csv")))) {
+      for (int i = 1; i < wordVectors.getColumnDimension(); i++) {
+        double[] a = wordVectors.getColumnVector(i).toArray();
+        String csq = Arrays.toString(Arrays.copyOfRange(a, 1, a.length));
+        csq = csq.substring(1, csq.length() - 1); // strip the '[' and ']' of Arrays.toString
+        bufferedWriter.append(csq);
+        bufferedWriter.append(", ");
+        bufferedWriter.append(vocabulary.get(i - 1));
+        bufferedWriter.newLine();
+      }
+    }
+
+    // for post processing with dimensionality reduction (PCA, t-SNE, etc.):
+    // values: awk '{$hiddenSize=""; print $0}' target/sg-vectors.csv
+    // keys: awk '{print $hiddenSize}' target/sg-vectors.csv
+  }
+
+  /**
+   * For every word prints the three other words with the highest similarity. Words live in the
+   * columns {@code 1..n} of {@code wordVectors} (column {@code i} belongs to
+   * {@code vocabulary.get(i - 1)}); the similarity of two words is the reciprocal of the
+   * distance between their columns, each taken without its first element.
+   *
+   * @param vocabulary      the words, in column order
+   * @param wordVectors     the learned weights, one word per column
+   * @param distanceMeasure the distance to invert into a similarity
+   */
+  private void computeSimilarities(List<String> vocabulary, RealMatrix wordVectors, DistanceMeasure distanceMeasure) {
+    int columns = wordVectors.getColumnDimension();
+    for (int i = 1; i < columns; i++) {
+      double[] subjectVector = wordVectors.getColumn(i);
+      subjectVector = Arrays.copyOfRange(subjectVector, 1, subjectVector.length);
+      // best, second best and third best similarity seen so far, with their column indices
+      double maxSimilarity = -Double.MAX_VALUE;
+      double maxSimilarity1 = -Double.MAX_VALUE;
+      double maxSimilarity2 = -Double.MAX_VALUE;
+      int j0 = -1;
+      int j1 = -1;
+      int j2 = -1;
+      for (int j = 1; j < columns; j++) {
+        if (i != j) {
+          double[] vector = wordVectors.getColumn(j);
+          vector = Arrays.copyOfRange(vector, 1, vector.length);
+          double similarity = 1d / distanceMeasure.compute(subjectVector, vector);
+          if (similarity > maxSimilarity) {
+            // new best: shift the previous best and second best one place down
+            maxSimilarity2 = maxSimilarity1;
+            j2 = j1;
+
+            maxSimilarity1 = maxSimilarity;
+            j1 = j0;
+
+            maxSimilarity = similarity;
+            j0 = j;
+          } else if (similarity > maxSimilarity1) {
+            maxSimilarity2 = maxSimilarity1;
+            j2 = j1;
+
+            maxSimilarity1 = similarity;
+            j1 = j;
+          } else if (similarity > maxSimilarity2) {
+            maxSimilarity2 = similarity;
+            j2 = j;
+          }
+        }
+      }
+      if (j0 > 0 && j1 > 0 && j2 > 0) {
+        System.out.println(vocabulary.get(i - 1) + " -> "
+                + vocabulary.get(j0 - 1) + ", "
+                + vocabulary.get(j1 - 1) + ", "
+                + vocabulary.get(j2 - 1));
+      } else {
+        // column i maps to word i - 1, exactly as in the branch above; get(i) named the
+        // following word and could run past the end of the list on the last column
+        System.err.println("no similarity for '" + vocabulary.get(i - 1) + "' with " + distanceMeasure);
+      }
+    }
+  }
+
+  /**
+   * Drains {@code fragments} and turns every fragment into one {@link HotEncodedSample} that
+   * holds vocabulary indices: a single input index and {@code window - 1} output indices.
+   *
+   * NOTE(review): the nested loop always reads {@code fragment.get(i)}, so once it has finished
+   * the input word is the LAST token of the fragment and the collected output words are each
+   * token repeated {@code fragment.size() - 1} times (only the first {@code window - 1} of them
+   * are used). Pairing a centre word with its neighbours would need {@code fragment.get(j)} in
+   * the output branch and one sample per {@code i}; confirm the intent before changing it,
+   * because that alters the training data.
+   *
+   * @param vocabulary the words; a token's index in this list is its encoded value
+   * @param fragments  the token windows to consume (emptied by this call)
+   * @param window     number of tokens per fragment
+   * @return one sample per fragment, in polling order
+   */
+  private Collection<HotEncodedSample> createTrainingSet(final List<String> vocabulary, Queue<List<byte[]>> fragments, int window) throws IOException {
+    final long startedAt = System.currentTimeMillis();
+    Collection<HotEncodedSample> trainingSamples = new LinkedList<>();
+    for (List<byte[]> fragment = fragments.poll(); fragment != null; fragment = fragments.poll()) {
+      final int tokens = fragment.size();
+      byte[] inputWord = null;
+      List<byte[]> outputWords = new ArrayList<>(tokens - 1);
+      for (int i = 0; i < tokens; i++) {
+        byte[] token = fragment.get(i);
+        for (int j = 0; j < tokens; j++) {
+          if (i != j) {
+            outputWords.add(token);
+          } else {
+            inputWord = token;
+          }
+        }
+      }
+
+      double[] outputIndices = new double[window - 1];
+      for (int k = 0; k < outputIndices.length; k++) {
+        outputIndices[k] = (double) vocabulary.indexOf(new String(outputWords.get(k)));
+      }
+
+      double[] inputIndices = {(double) vocabulary.indexOf(new String(inputWord))};
+
+      trainingSamples.add(new HotEncodedSample(inputIndices, outputIndices, vocabulary.size()));
+    }
+
+    final long finishedAt = System.currentTimeMillis();
+    System.out.println("training set created in " + (finishedAt - startedAt) / 60000 + " minutes");
+
+    return trainingSamples;
+  }
+
+  /**
+   * Builds a one-hot vector.
+   *
+   * @param index position to set to {@code 1}
+   * @param size  length of the vector
+   * @return a new array of {@code size} zeros with a single {@code 1} at {@code index}
+   * @throws ArrayIndexOutOfBoundsException if {@code index} is not within {@code [0, size)}
+   */
+  public static double[] hotEncode(int index, int size) {
+    double[] vector = new double[size]; // new arrays are already all zeros, no fill needed
+    if (index < 0 || index >= size) {
+      throw new ArrayIndexOutOfBoundsException("index " + index + " is outside [0, " + size + ")");
+    }
+    vector[index] = 1d;
+    return vector;
+  }
+
+
+  /**
+   * Builds the vocabulary of the text at {@code path}: its distinct tokens in natural order.
+   *
+   * @param path the text file to read
+   * @return a sorted list without duplicates
+   * @throws IOException if the file cannot be read
+   */
+  private List<String> getVocabulary(Path path) throws IOException {
+    Set<String> distinct = new HashSet<>(readTokens(path));
+    List<String> sorted = new ArrayList<>(distinct);
+    Collections.sort(sorted);
+    return sorted;
+  }
+
+  /**
+   * Decodes the whole file with the platform encoding, normalises it with
+   * {@code cleanString} and splits it on whitespace.
+   *
+   * The vocabulary and the fragments are both derived from this single token stream, so a
+   * fragment can only contain tokens the vocabulary was built from. Reading the text in one
+   * piece replaces the former fixed-size buffer loop, which decoded stale bytes after a short
+   * read (rewind() instead of flip()), glued or split words at buffer boundaries, dropped the
+   * last token and hid I/O errors behind a message on standard error.
+   */
+  private List<String> readTokens(Path path) throws IOException {
+    Charset charset = Charset.forName(System.getProperty("file.encoding"));
+    String text = new String(Files.readAllBytes(path), charset);
+    Splitter splitter = Splitter.on(Pattern.compile("[\\n\\s]")).omitEmptyStrings().trimResults();
+    return splitter.splitToList(cleanString(text));
+  }
+
+  /**
+   * Lower-cases the text, replaces {@code . ; , :} with a space, removes every hyphen that is
+   * followed by whitespace (together with that whitespace character) and removes double quotes.
+   */
+  private String cleanString(CharSequence text) {
+    String s = text.toString();
+    return s.toLowerCase()
+        .replaceAll("\\.", " ").replaceAll("\\;", " ").replaceAll("\\,", " ").replaceAll("\\:", " ")
+        .replaceAll("\\-\\s", "").replaceAll("\\\"", "");
+  }
+
+  /**
+   * Builds a vocabulary from sentences given as bytes: each sentence is decoded with the
+   * platform encoding and split on single spaces.
+   *
+   * @param sentences the encoded sentences
+   * @return the distinct tokens in natural order
+   */
+  private List<String> getVocabulary(Collection<byte[]> sentences) {
+    long start = System.currentTimeMillis();
+    // a set removes duplicates in one pass instead of scanning a list for every token
+    Set<String> distinct = new HashSet<>();
+    for (byte[] sentence : sentences) {
+      Collections.addAll(distinct, new String(sentence).split(" "));
+    }
+    System.out.println("sorting vocabulary");
+    List<String> vocabulary = new ArrayList<>(distinct);
+    Collections.sort(vocabulary);
+    long end = System.currentTimeMillis();
+    System.out.println("vocabulary generated in " + (end - start) / 60000 + " minutes");
+    return vocabulary;
+  }
+
+  /**
+   * Slides a window of {@code w} tokens over the text at {@code path}, one token at a time.
+   *
+   * @param path the text file to read
+   * @param w    number of tokens per fragment, at least 1
+   * @return every run of {@code w} consecutive tokens in text order, each token encoded with the
+   *         platform encoding; empty when the text has fewer than {@code w} tokens
+   * @throws IOException              if the file cannot be read
+   * @throws IllegalArgumentException if {@code w} is smaller than 1
+   */
+  private Queue<List<byte[]>> getFragments(Path path, int w) throws IOException {
+    if (w < 1) {
+      throw new IllegalArgumentException("window must be at least 1 but was " + w);
+    }
+    long start = System.currentTimeMillis();
+    Queue<List<byte[]>> fragments = new ConcurrentLinkedDeque<>();
+
+    List<String> tokens = readTokens(path);
+    for (int first = 0; first + w <= tokens.size(); first++) {
+      List<byte[]> fragment = new ArrayList<>(w);
+      for (String token : tokens.subList(first, first + w)) {
+        fragment.add(token.getBytes());
+      }
+      fragments.add(fragment);
+    }
+    long end = System.currentTimeMillis();
+    System.out.println("fragments read in " + (end - start) / 60000 + " minutes (" + fragments.size() + ")");
+    return fragments;
+  }
+
+  /**
+   * Reads the {@code /word2vec/test.txt} classpath resource line by line; every line is
+   * lower-cased, {@code . ; , :} are replaced by a space and {@code -} is removed.
+   *
+   * @return the cleaned lines, in file order
+   * @throws IOException if the resource is missing or cannot be read
+   */
+  private Collection<String> getSentences() throws IOException {
+    InputStream resourceAsStream = getClass().getResourceAsStream("/word2vec/test.txt");
+    if (resourceAsStream == null) {
+      // getResourceAsStream signals a missing resource with null; report it by name
+      throw new IOException("classpath resource /word2vec/test.txt not found");
+    }
+    Collection<String> sentences = new LinkedList<>();
+    // closing the reader also closes the underlying resource stream
+    try (BufferedReader bufferedReader = new BufferedReader(new InputStreamReader(resourceAsStream))) {
+      String line;
+      while ((line = bufferedReader.readLine()) != null) {
+        String cleanLine = line.toLowerCase().replaceAll("\\.", " ").replaceAll(";", " ").replaceAll(",", " ").replaceAll(":", " ").replaceAll("-", "");
+        sentences.add(cleanLine);
+      }
+    }
+    return sentences;
+  }
+
+  /**
+   * Builds two weight matrices: the first has {@code hiddenSize} rows and {@code inputSize}
+   * columns, the second has {@code outputSize} rows and as many columns as the first has rows.
+   * Column 0 of every filled row is {@code 1}; the other cells are {@code k / 101} for a random
+   * integer {@code k} in {@code [0, 100)}. Row 0 of the first matrix is left at zero.
+   *
+   * @return the two matrices, in layer order
+   */
+  private RealMatrix[] createRandomWeights(int inputSize, int hiddenSize, int outputSize) {
+    final Random random = new Random();
+    final int weightsCount = 2;
+    final int lastLayer = weightsCount - 1;
+
+    RealMatrix[] initialWeights = new RealMatrix[weightsCount];
+    for (int layer = 0; layer < weightsCount; layer++) {
+      final boolean last = layer == lastLayer;
+      final int cols = layer == 0 ? inputSize : initialWeights[layer - 1].getRowDimension();
+      final int rows = (layer != 0 && last) ? outputSize : hiddenSize;
+      double[][] values = new double[rows][cols];
+
+      // first row: zeros, except in the last matrix where it is filled like every other row
+      for (int col = 0; col < cols; col++) {
+        if (!last) {
+          values[0][col] = 0;
+        } else {
+          values[0][col] = col == 0 ? 1d : random.nextInt(100) / 101d;
+        }
+      }
+
+      // remaining rows: a leading 1 followed by random values
+      for (int row = 1; row < rows; row++) {
+        for (int col = 0; col < cols; col++) {
+          values[row][col] = col == 0 ? 1d : random.nextInt(100) / 101d;
+        }
+      }
+      initialWeights[layer] = MatrixUtils.createRealMatrix(values);
+    }
+    return initialWeights;
+  }
+}
\ No newline at end of file

Propchange: 
labs/yay/trunk/core/src/test/java/org/apache/yay/WordVectorsTest.java
------------------------------------------------------------------------------
    svn:eol-style = native



---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@labs.apache.org
For additional commands, e-mail: commits-h...@labs.apache.org

Reply via email to