Author: tommaso
Date: Thu Feb 25 11:17:56 2016
New Revision: 1732287

URL: http://svn.apache.org/viewvc?rev=1732287&view=rev
Log:
refactored MLN (MultiLayerNetwork) and improved SGN (SkipGramNetwork); upgraded to commons-math 3.6

Modified:
    labs/yay/trunk/core/src/main/java/org/apache/yay/HotEncodedSample.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/SkipGramNetwork.java
    labs/yay/trunk/core/src/test/java/org/apache/yay/MultiLayerNetworkTest.java
    labs/yay/trunk/core/src/test/java/org/apache/yay/SkipGramNetworkTest.java
    labs/yay/trunk/core/src/test/resources/word2vec/test.txt
    labs/yay/trunk/pom.xml

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/HotEncodedSample.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/HotEncodedSample.java?rev=1732287&r1=1732286&r2=1732287&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/HotEncodedSample.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/HotEncodedSample.java Thu Feb 25 11:17:56 2016
@@ -34,9 +34,8 @@ public class HotEncodedSample extends Sa
 
   @Override
   public double[] getInputs() {
-    double[] inputs = new double[1 + this.inputs.length * vocabularySize];
-    inputs[0] = 1d;
-    int i = 1;
+    double[] inputs = new double[this.inputs.length * vocabularySize];
+    int i = 0;
     for (double d : this.inputs) {
       double[] currentInput = hotEncode((int) d);
       System.arraycopy(currentInput, 0, inputs, i, currentInput.length);
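
Context for this hunk: getInputs() no longer reserves a leading bias slot (the old inputs[0] = 1d), since biases now live in separate matrices in SkipGramNetwork; the encoded vector is just one vocabulary-sized one-hot block per input word, concatenated. A minimal sketch of that encoding (hypothetical standalone class; hotEncode mirrors the behavior implied by the code above):

    import java.util.Arrays;

    public class HotEncodeSketch {
      // one vocabulary-sized block with a single 1 at the word's index
      static double[] hotEncode(int index, int vocabularySize) {
        double[] v = new double[vocabularySize];
        v[index] = 1d;
        return v;
      }

      public static void main(String[] args) {
        int vocabularySize = 5;
        double[] wordIndices = {2, 4};
        double[] inputs = new double[wordIndices.length * vocabularySize];
        int i = 0; // starts at 0: no bias slot any more
        for (double d : wordIndices) {
          double[] current = hotEncode((int) d, vocabularySize);
          System.arraycopy(current, 0, inputs, i, current.length);
          i += current.length;
        }
        // prints [0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0]
        System.out.println(Arrays.toString(inputs));
      }
    }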

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/SkipGramNetwork.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/SkipGramNetwork.java?rev=1732287&r1=1732286&r2=1732287&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/SkipGramNetwork.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/SkipGramNetwork.java Thu Feb 25 11:17:56 2016
@@ -24,6 +24,7 @@ import org.apache.commons.math3.linear.M
 import org.apache.commons.math3.linear.RealMatrix;
 import org.apache.commons.math3.linear.RealMatrixChangingVisitor;
 import org.apache.commons.math3.linear.RealMatrixPreservingVisitor;
+import org.apache.commons.math3.linear.RealVector;
 
 import java.io.IOException;
 import java.nio.ByteBuffer;
@@ -68,11 +69,44 @@ public class SkipGramNetwork {
   * the second row of weights[0] holds the weights from each neuron of the first layer into the second neuron of the second layer, etc.
    */
   private RealMatrix[] weights;
+  private RealMatrix[] biases;
+  private Sample[] samples;
 
 
   private SkipGramNetwork(Configuration configuration) {
     this.configuration = configuration;
     this.weights = createRandomWeights();
+    this.biases = createRandomBiases();
+  }
+
+  private RealMatrix[] createRandomBiases() {
+    Random r = new Random();
+
+    RealMatrix[] initialWeights = new RealMatrix[weights.length];
+
+    for (int i = 0; i < initialWeights.length; i++) {
+
+      RealMatrix matrix = MatrixUtils.createRealMatrix(1, weights[i].getRowDimension());
+      matrix.walkInOptimizedOrder(new RealMatrixChangingVisitor() {
+        @Override
+        public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) {
+
+        }
+
+        @Override
+        public double visit(int row, int column, double value) {
+          return 1;//r.nextInt(100000) / 10000001d;
+        }
+
+        @Override
+        public double end() {
+          return 0;
+        }
+      });
+
+      initialWeights[i] = matrix;
+    }
+    return initialWeights;
   }
 
   public RealMatrix[] getWeights() {
@@ -83,13 +117,21 @@ public class SkipGramNetwork {
     return configuration.vocabulary;
   }
 
+  public double[] predictOutput(double[] input) throws Exception {
+
+    RealMatrix hidden = rectifierFunction.applyMatrix(MatrixUtils.createRowRealMatrix(input).multiply(weights[0].transpose()).
+            add(biases[0]));
+    RealMatrix pscores = softmaxActivationFunction.applyMatrix(hidden.multiply(weights[1].transpose()).add(biases[1]));
+
+    RealVector d = pscores.getRowVector(0);
+    return d.toArray();
+  }
+
   private RealMatrix[] createRandomWeights() {
     Random r = new Random();
     int[] conf = new int[]{configuration.inputs, configuration.vectorSize, configuration.outputs};
     int[] layers = new int[conf.length];
-    for (int i = 0; i < layers.length; i++) {
-      layers[i] = conf[i] + (i < layers.length - 1 ? 1 : 0);
-    }
+    System.arraycopy(conf, 0, layers, 0, layers.length);
     int weightsCount = layers.length - 1;
 
     RealMatrix[] initialWeights = new RealMatrix[weightsCount];
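
The new predictOutput above is a plain two-layer forward pass with explicit biases: hidden = ReLU(x * W0^T + b0), probabilities = softmax(hidden * W1^T + b1). A self-contained sketch with toy dimensions (assumed shapes; the inline rectifier and softmax stand in for rectifierFunction/softmaxActivationFunction):

    import java.util.Arrays;
    import org.apache.commons.math3.linear.MatrixUtils;
    import org.apache.commons.math3.linear.RealMatrix;

    public class ForwardPassSketch {
      public static void main(String[] args) {
        RealMatrix x  = MatrixUtils.createRowRealMatrix(new double[]{1, 0, 0});     // 1x3 input
        RealMatrix w0 = MatrixUtils.createRealMatrix(new double[][]{{.1, .2, .3},
                                                                    {.4, .5, .6}}); // 2x3
        RealMatrix b0 = MatrixUtils.createRowRealMatrix(new double[]{1, 1});        // 1x2
        RealMatrix w1 = MatrixUtils.createRealMatrix(new double[][]{{.1, .2},
                                                                    {.3, .4},
                                                                    {.5, .6}});     // 3x2
        RealMatrix b1 = MatrixUtils.createRowRealMatrix(new double[]{1, 1, 1});     // 1x3

        RealMatrix hidden = x.multiply(w0.transpose()).add(b0);                     // 1x2
        for (int c = 0; c < hidden.getColumnDimension(); c++)                       // ReLU
          hidden.setEntry(0, c, Math.max(0d, hidden.getEntry(0, c)));

        double[] s = hidden.multiply(w1.transpose()).add(b1).getRow(0);             // 1x3 scores
        double max = Arrays.stream(s).max().getAsDouble();                          // stabilized softmax
        double sum = 0;
        for (int c = 0; c < s.length; c++) { s[c] = Math.exp(s[c] - max); sum += s[c]; }
        for (int c = 0; c < s.length; c++) s[c] /= sum;
        System.out.println(Arrays.toString(s)); // probabilities summing to 1
      }
    }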
@@ -97,7 +139,6 @@ public class SkipGramNetwork {
     for (int i = 0; i < weightsCount; i++) {
 
      RealMatrix matrix = MatrixUtils.createRealMatrix(layers[i + 1], layers[i]);
-      final int finalI = i;
       matrix.walkInOptimizedOrder(new RealMatrixChangingVisitor() {
         @Override
        public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) {
@@ -106,12 +147,7 @@ public class SkipGramNetwork {
 
         @Override
         public double visit(int row, int column, double value) {
-          if (finalI != weightsCount - 1 && row == 0) {
-            return 0d;
-          } else if (column == 0) {
-            return 1d;
-          }
-          return r.nextInt(100) / 101d;
+          return r.nextInt(10) / 1000000001d;
         }
 
         @Override
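
A note on the changed initialization scale: r.nextInt(10) / 1000000001d draws from roughly {0, 1e-9, ..., 9e-9}, i.e. near-zero weights, replacing the earlier r.nextInt(100) / 101d (roughly uniform over [0, 0.98]). A quick check (sketch):

    public class InitScaleSketch {
      public static void main(String[] args) {
        java.util.Random r = new java.util.Random();
        // new scheme: one of {0, ~1e-9, ..., ~9e-9} -- effectively near-zero
        System.out.println(r.nextInt(10) / 1000000001d);
        // old scheme: roughly uniform over {0, 1/101, ..., 99/101}
        System.out.println(r.nextInt(100) / 101d);
      }
    }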
@@ -126,6 +162,50 @@ public class SkipGramNetwork {
 
   }
 
+  private void evaluate() throws Exception {
+    double cc = 0;
+    double wc = 0;
+    for (Sample sample : samples) {
+      int window = configuration.window;
+      Collection<Integer> exps = new ArrayList<>(window - 1);
+      Collection<Integer> acts = new ArrayList<>(window - 1);
+      double[] inputs = sample.getInputs();
+      double[] actualOutputs = predictOutput(inputs);
+      double[] expectedOutputs = sample.getOutputs();
+      int j = 0;
+      for (int i = 0; i < window - 1; i++) {
+        int actualMax = getMaxIndex(actualOutputs, j, j + inputs.length - 1);
+        int expectedMax = getMaxIndex(expectedOutputs, j, j + inputs.length - 1);
+        exps.add(expectedMax);
+        acts.add(actualMax);
+        j += i + inputs.length - 2;
+      }
+      boolean c = true;
+      for (Integer a : acts) {
+        c &= exps.contains(a);
+      }
+      if (c) {
+        cc++;
+      } else {
+        wc++;
+      }
+
+    }
+    System.out.println("accuracy: " + (cc / (wc + cc)));
+  }
+
+  private int getMaxIndex(double[] array, int start, int end) {
+    double largest = array[start];
+    int index = start;
+    for (int i = start + 1; i < end; i++) {
+      if (array[i] >= largest) {
+        largest = array[i];
+        index = i;
+      }
+    }
+    return index;
+  }
+
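
The evaluate() method above treats the network output as window - 1 concatenated vocabulary-sized blocks and compares, per block, the argmax of the predicted and expected vectors. A compact equivalent of the block-wise argmax (hypothetical helper; start inclusive, end exclusive, mirroring getMaxIndex):

    public class ArgmaxSketch {
      // slice argmax, start inclusive / end exclusive
      static int argmax(double[] a, int start, int end) {
        int best = start;
        for (int i = start + 1; i < end; i++)
          if (a[i] >= a[best]) best = i;
        return best;
      }

      public static void main(String[] args) {
        double[] out = {0.1, 0.7, 0.2, 0.5, 0.4, 0.1}; // two blocks of size 3
        System.out.println(argmax(out, 0, 3)); // 1
        System.out.println(argmax(out, 3, 6)); // 3
      }
    }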
   // --- batch gradient descent ---
 
   /**
@@ -140,29 +220,65 @@ public class SkipGramNetwork {
     int iterations = 0;
 
     double cost = Double.MAX_VALUE;
+
+    RealMatrix x = MatrixUtils.createRealMatrix(samples.length, samples[0].getInputs().length);
+    RealMatrix y = MatrixUtils.createRealMatrix(samples.length, samples[0].getOutputs().length);
+    int i = 0;
+    for (Sample sample : samples) {
+      x.setRow(i, ArrayUtils.addAll(sample.getInputs()));
+      y.setRow(i, ArrayUtils.addAll(sample.getOutputs()));
+      i++;
+    }
+
     long start = System.currentTimeMillis();
     while (true) {
-      if (iterations % (1 + (configuration.maxIterations / 100)) == 0) {
-        long time = (System.currentTimeMillis() - start) / 1000;
+      long time = (System.currentTimeMillis() - start) / 1000;
+      if (iterations % (1 + (configuration.maxIterations / 100)) == 0 || time % 300 < 2) {
         if (time > 60) {
           System.out.println("cost is " + cost + " after " + iterations + " iterations in " + (time / 60) + " minutes (" + ((double) iterations / time) + " ips)");
         }
       }
-      double newCost = 0;
-      RealMatrix x = MatrixUtils.createRealMatrix(samples.length, samples[0].getInputs().length);
-      RealMatrix y = MatrixUtils.createRealMatrix(samples.length, samples[0].getOutputs().length);
-      int i = 0;
-      for (Sample sample : samples) {
-        x.setRow(i, ArrayUtils.addAll(sample.getInputs()));
-        y.setRow(i, ArrayUtils.addAll(sample.getOutputs()));
-        i++;
+      if (iterations % 1000 == 0) {
+        evaluate();
       }
 
       RealMatrix w0t = weights[0].transpose();
       final RealMatrix w1t = weights[1].transpose();
 
       RealMatrix hidden = rectifierFunction.applyMatrix(x.multiply(w0t));
+      hidden.walkInOptimizedOrder(new RealMatrixChangingVisitor() {
+        @Override
+        public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) {
+
+        }
+
+        @Override
+        public double visit(int row, int column, double value) {
+          return value + biases[0].getEntry(0, column);
+        }
+
+        @Override
+        public double end() {
+          return 0;
+        }
+      });
       RealMatrix scores = hidden.multiply(w1t);
+      scores.walkInOptimizedOrder(new RealMatrixChangingVisitor() {
+        @Override
+        public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) {
+
+        }
+
+        @Override
+        public double visit(int row, int column, double value) {
+          return value + biases[1].getEntry(0, column);
+        }
+
+        @Override
+        public double end() {
+          return 0;
+        }
+      });
 
       RealMatrix probs = softmaxActivationFunction.applyMatrix(scores);
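
The two bias visitors above exist because commons-math RealMatrix has no broadcast addition: the 1xN bias row has to be added to every row of the MxN activation matrix entry by entry. A loop-based equivalent with the same effect (sketch):

    import org.apache.commons.math3.linear.MatrixUtils;
    import org.apache.commons.math3.linear.RealMatrix;

    public class BiasBroadcastSketch {
      // commons-math has no broadcasting; this mirrors the visitors above
      static void addRowBias(RealMatrix m, RealMatrix bias) {
        for (int r = 0; r < m.getRowDimension(); r++)
          for (int c = 0; c < m.getColumnDimension(); c++)
            m.setEntry(r, c, m.getEntry(r, c) + bias.getEntry(0, c));
      }

      public static void main(String[] args) {
        RealMatrix m = MatrixUtils.createRealMatrix(new double[][]{{1, 2}, {3, 4}});
        addRowBias(m, MatrixUtils.createRowRealMatrix(new double[]{10, 20}));
        System.out.println(m); // rows become {11, 22} and {13, 24}
      }
    }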
 
@@ -221,11 +337,34 @@ public class SkipGramNetwork {
           return d;
         }
       });
-      newCost = dataLoss + 0.5 * 0.03 * reg;
+      reg += weights[1].walkInOptimizedOrder(new RealMatrixPreservingVisitor() {
+        private double d = 0d;
+
+        @Override
+        public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) {
+
+        }
+
+        @Override
+        public void visit(int row, int column, double value) {
+          d += Math.pow(value, 2);
+        }
+
+        @Override
+        public double end() {
+          return d;
+        }
+      });
+
+      double regLoss = 0.5 * configuration.regularizationLambda * reg;
+      double newCost = dataLoss + regLoss;
+      if (iterations == 0) {
+        System.out.println("started with cost = " + dataLoss + " + " + 
regLoss);
+      }
 
       if (Double.POSITIVE_INFINITY == newCost || newCost > cost) {
         throw new Exception("failed to converge at iteration " + iterations + " with alpha " + configuration.alpha + " : cost going from " + cost + " to " + newCost);
-      } else if (iterations > 1 && (newCost < configuration.threshold || iterations > configuration.maxIterations)) {
+      } else if (iterations > 1 && (newCost < configuration.threshold || iterations > configuration.maxIterations || cost - newCost < configuration.threshold)) {
         cost = newCost;
         System.out.println("successfully converged after " + (iterations - 1) + " iterations (alpha:" + configuration.alpha + ",threshold:" + configuration.threshold + ") with cost " + newCost);
         break;
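
With this hunk the cost sums squared entries of both weight matrices, so the penalty is the standard L2 term: newCost = dataLoss + 0.5 * lambda * (sum W0_ij^2 + sum W1_ij^2), with lambda = configuration.regularizationLambda replacing the previously hard-coded constant. A compact equivalent of the accumulation (sketch):

    import org.apache.commons.math3.linear.MatrixUtils;
    import org.apache.commons.math3.linear.RealMatrix;

    public class L2PenaltySketch {
      // sum of squared entries over any number of weight matrices
      static double sumOfSquares(RealMatrix... ws) {
        double reg = 0d;
        for (RealMatrix w : ws)
          for (int r = 0; r < w.getRowDimension(); r++)
            for (int c = 0; c < w.getColumnDimension(); c++)
              reg += w.getEntry(r, c) * w.getEntry(r, c);
        return reg;
      }

      public static void main(String[] args) {
        RealMatrix w0 = MatrixUtils.createRealMatrix(new double[][]{{1, 2}, {3, 4}});
        RealMatrix w1 = MatrixUtils.createRealMatrix(new double[][]{{2, 0}});
        double lambda = 3e-12; // the new configuration.regularizationLambda default
        System.out.println(0.5 * lambda * sumOfSquares(w0, w1)); // regLoss
      }
    }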
@@ -238,7 +377,7 @@ public class SkipGramNetwork {
 
       // calculate the derivatives to update the parameters
 
-      RealMatrix dscores = probs;
+      RealMatrix dscores = probs.copy();
       dscores.walkInOptimizedOrder(new RealMatrixChangingVisitor() {
         @Override
        public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) {
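
probs.copy() fixes an aliasing bug: dscores previously pointed at the same matrix as probs, so the in-place visitor mutated the probabilities too. A minimal illustration of the difference (sketch; the visitor body, elided in this hunk, is assumed to apply the usual softmax cross-entropy gradient, p - y):

    import org.apache.commons.math3.linear.MatrixUtils;
    import org.apache.commons.math3.linear.RealMatrix;

    public class CopyVsAliasSketch {
      public static void main(String[] args) {
        RealMatrix probs = MatrixUtils.createRowRealMatrix(new double[]{0.7, 0.2, 0.1});

        RealMatrix alias = probs;        // old behavior: same underlying data
        RealMatrix copy  = probs.copy(); // new behavior: independent matrix

        copy.setEntry(0, 0, copy.getEntry(0, 0) - 1d); // p - y for true class 0
        System.out.println(probs.getEntry(0, 0));      // still 0.7
        alias.setEntry(0, 0, alias.getEntry(0, 0) - 1d);
        System.out.println(probs.getEntry(0, 0));      // now -0.3: probs was clobbered
      }
    }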
@@ -267,11 +406,25 @@ public class SkipGramNetwork {
 
         @Override
         public double visit(int row, int column, double value) {
-          if (column != 0) {
-            return value + 0.3 * w1t.getEntry(row, column);
-          } else {
-            return value;
-          }
+          return value + configuration.regularizationLambda * w1t.getEntry(row, column);
+        }
+
+        @Override
+        public double end() {
+          return 0;
+        }
+      });
+
+      RealMatrix db2 = MatrixUtils.createRealMatrix(biases[1].getRowDimension(), biases[1].getColumnDimension());
+      dscores.walkInOptimizedOrder(new RealMatrixPreservingVisitor() {
+        @Override
+        public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) {
+
+        }
+
+        @Override
+        public void visit(int row, int column, double value) {
+          db2.setEntry(0, column, db2.getEntry(0, column) + value);
         }
 
         @Override
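
db2 above accumulates the column sums of dscores: because each bias entry is added to every sample row in the forward pass, its gradient is the upstream gradient summed over the batch dimension. A loop equivalent (sketch):

    import org.apache.commons.math3.linear.MatrixUtils;
    import org.apache.commons.math3.linear.RealMatrix;

    public class BiasGradientSketch {
      // bias gradient = column-wise sum of the upstream gradient
      static RealMatrix columnSums(RealMatrix dscores) {
        RealMatrix db = MatrixUtils.createRealMatrix(1, dscores.getColumnDimension());
        for (int r = 0; r < dscores.getRowDimension(); r++)
          for (int c = 0; c < dscores.getColumnDimension(); c++)
            db.setEntry(0, c, db.getEntry(0, c) + dscores.getEntry(r, c));
        return db;
      }

      public static void main(String[] args) {
        RealMatrix g = MatrixUtils.createRealMatrix(new double[][]{{1, 2}, {3, 4}});
        System.out.println(columnSums(g)); // {{4, 6}}
      }
    }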
@@ -281,6 +434,40 @@ public class SkipGramNetwork {
       });
 
       RealMatrix dhidden = dscores.multiply(weights[1]);
+      dhidden.walkInOptimizedOrder(new RealMatrixChangingVisitor() {
+        @Override
+        public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) {
+
+        }
+
+        @Override
+        public double visit(int row, int column, double value) {
+          return value < 0 ? 0 : value;
+        }
+
+        @Override
+        public double end() {
+          return 0;
+        }
+      });
+
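
One caveat on the visitor above: textbook ReLU backpropagation gates the gradient by the forward activation (zeroing dhidden wherever the corresponding hidden entry is <= 0), whereas this visitor zeroes entries of dhidden that are themselves negative, which is not the same operation. For comparison, an activation-gated version (sketch; assumes 'hidden' is the forward ReLU output already in scope here):

    // activation-gated ReLU backprop (sketch): zero the gradient wherever the
    // forward activation was zero, regardless of the gradient's own sign
    for (int r = 0; r < dhidden.getRowDimension(); r++)
      for (int c = 0; c < dhidden.getColumnDimension(); c++)
        if (hidden.getEntry(r, c) <= 0d)
          dhidden.setEntry(r, c, 0d);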
+      RealMatrix db = MatrixUtils.createRealMatrix(biases[0].getRowDimension(), biases[0].getColumnDimension());
+      dhidden.walkInOptimizedOrder(new RealMatrixPreservingVisitor() {
+        @Override
+        public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) {
+
+        }
+
+        @Override
+        public void visit(int row, int column, double value) {
+          db.setEntry(0, column, db.getEntry(0, column) + value);
+        }
+
+        @Override
+        public double end() {
+          return 0;
+        }
+      });
 
       RealMatrix dW = x.transpose().multiply(dhidden);
       dW.walkInOptimizedOrder(new RealMatrixChangingVisitor() {
@@ -291,11 +478,42 @@ public class SkipGramNetwork {
 
         @Override
         public double visit(int row, int column, double value) {
-          if (column != 0) {
-            return value + 0.03 * w0t.getEntry(row, column);
-          } else {
-            return value;
-          }
+          return value + configuration.regularizationLambda * w0t.getEntry(row, column);
+        }
+
+        @Override
+        public double end() {
+          return 0;
+        }
+      });
+
+      // update bias
+      biases[0].walkInOptimizedOrder(new RealMatrixChangingVisitor() {
+        @Override
+        public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) {
+
+        }
+
+        @Override
+        public double visit(int row, int column, double value) {
+          return value - configuration.alpha * db.getEntry(row, column);
+        }
+
+        @Override
+        public double end() {
+          return 0;
+        }
+      });
+
+      biases[1].walkInOptimizedOrder(new RealMatrixChangingVisitor() {
+        @Override
+        public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) {
+
+        }
+
+        @Override
+        public double visit(int row, int column, double value) {
+          return value - configuration.alpha * db2.getEntry(row, column);
         }
 
         @Override
@@ -307,10 +525,7 @@ public class SkipGramNetwork {
      RealMatrix[] derivatives = new RealMatrix[]{dW.transpose(), dW2.transpose()};
 
       // update the weights
-      RealMatrix[] updatedParameters = new RealMatrix[weights.length];
-
       for (int l = 0; l < weights.length; l++) {
-        RealMatrix realMatrix = weights[l].copy();
         final int finalL = l;
         RealMatrixChangingVisitor visitor = new RealMatrixChangingVisitor() {
 
@@ -321,11 +536,7 @@ public class SkipGramNetwork {
 
           @Override
           public double visit(int row, int column, double value) {
-            if (!(row == 0 && value == 0d) && !(column == 0 && value == 1d)) {
              return value - configuration.alpha * derivatives[finalL].getEntry(row, column);
-            } else {
-              return value;
-            }
           }
 
           @Override
@@ -333,10 +544,8 @@ public class SkipGramNetwork {
             return 0;
           }
         };
-        realMatrix.walkInOptimizedOrder(visitor);
-        updatedParameters[l] = realMatrix;
+        weights[l].walkInOptimizedOrder(visitor);
       }
-      weights = updatedParameters;
 
       iterations++;
     }
@@ -359,6 +568,10 @@ public class SkipGramNetwork {
     return new Builder();
   }
 
+  public Sample[] getSamples() {
+    return samples;
+  }
+
   // --- skip gram neural network configuration ---
 
   private static class Configuration {
@@ -371,8 +584,9 @@ public class SkipGramNetwork {
     // user controlled parameters
     protected Path path;
     protected int maxIterations;
-    protected double alpha = 0.003d;
-    protected double threshold = 0.004d;
+    protected double alpha = 0.0001d;
+    protected double regularizationLambda = 0.000000000003;
+    protected double threshold = 0.0000000000004d;
     protected int vectorSize;
     protected int window;
   }
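
For readability, the new defaults written in scientific notation (same values as the literals above; a presentation suggestion only):

    protected double alpha = 1e-4d;                 // was 0.003d
    protected double regularizationLambda = 3e-12d; // new parameter
    protected double threshold = 4e-13d;            // was 0.004d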
@@ -405,43 +619,53 @@ public class SkipGramNetwork {
       Queue<List<byte[]>> fragments = getFragments(this.configuration.path, 
this.configuration.window);
       assert !fragments.isEmpty() : "could not read fragments";
       System.out.println("generating vocabulary");
-      List<String> vocabulary = getVocabulary(this.configuration.path);
+//      List<String> vocabulary = getVocabulary(this.configuration.path);
+      List<String> vocabulary = getVocabulary(fragments);
       assert !vocabulary.isEmpty() : "could not read vocabulary";
       this.configuration.vocabulary = vocabulary;
 
       System.out.println("creating training set");
       Collection<HotEncodedSample> trainingSet = createTrainingSet(vocabulary, 
fragments, this.configuration.window);
       fragments.clear();
-      this.configuration.maxIterations = trainingSet.size() * 10;
+      this.configuration.maxIterations = trainingSet.size() * 100000;
 
       HotEncodedSample next = trainingSet.iterator().next();
 
-      this.configuration.inputs = next.getInputs().length - 1;
+      this.configuration.inputs = next.getInputs().length;
       this.configuration.outputs = next.getOutputs().length;
 
       SkipGramNetwork network = new SkipGramNetwork(configuration);
-      network.learnWeights(trainingSet.toArray(new Sample[trainingSet.size()]));
+      network.samples = trainingSet.toArray(new Sample[trainingSet.size()]);
+      network.learnWeights(network.samples);
       return network;
     }
 
+    private List<String> getVocabulary(Queue<List<byte[]>> fragments) {
+      List<String> vocabulary = new LinkedList<>();
+      for (List<byte[]> fragment : fragments) {
+        for (byte[] word : fragment) {
+          String s = new String(word);
+          if (!vocabulary.contains(s)) {
+            vocabulary.add(s);
+          }
+        }
+      }
+      return vocabulary;
+    }
+
     private Collection<HotEncodedSample> createTrainingSet(final List<String> 
vocabulary, Queue<List<byte[]>> fragments, int window) throws IOException {
       long start = System.currentTimeMillis();
       Collection<HotEncodedSample> samples = new LinkedList<>();
       List<byte[]> fragment;
       while ((fragment = fragments.poll()) != null) {
-        byte[] inputWord = null;
         List<byte[]> outputWords = new ArrayList<>(fragment.size() - 1);
-        for (int i = 0; i < fragment.size(); i++) {
-          for (int j = 0; j < fragment.size(); j++) {
-            byte[] token = fragment.get(i);
-            if (i == j) {
-              inputWord = token;
-            } else {
-              outputWords.add(token);
-            }
+        int inputIdx = fragment.size() / 2;
+        byte[] inputWord = fragment.get(inputIdx);
+        for (int k = 0; k < fragment.size(); k++) {
+          if (k != inputIdx) {
+            outputWords.add(fragment.get(k));
           }
         }
-        final byte[] finalInputWord = inputWord;
 
         double[] doubles = new double[window - 1];
         for (int i = 0; i < doubles.length; i++) {
@@ -449,7 +673,7 @@ public class SkipGramNetwork {
         }
 
         double[] inputs = new double[1];
-        inputs[0] = (double) vocabulary.indexOf(new String(finalInputWord));
+        inputs[0] = (double) vocabulary.indexOf(new String(inputWord));
 
         samples.add(new HotEncodedSample(inputs, doubles, vocabulary.size()));
 
@@ -480,9 +704,11 @@ public class SkipGramNetwork {
           if (splitSize > w) {
             for (int j = 0; j < splitSize - w; j++) {
               List<byte[]> fragment = new ArrayList<>(w);
-              fragment.add(previous.append(split.get(j)).toString().getBytes());
+              String str = split.get(j);
+              fragment.add(previous.append(str).toString().getBytes());
               for (int i = 1; i < w; i++) {
-                fragment.add(split.get(i + j).getBytes());
+                String s = split.get(i + j);
+                fragment.add(s.getBytes());
               }
              // TODO : this has to be used to re-use the tokens that have not been consumed in next iteration
               fragments.add(fragment);
@@ -546,7 +772,7 @@ public class SkipGramNetwork {
 
     private String cleanString(CharBuffer charBuffer) {
       String s = charBuffer.toString();
-      return s.toLowerCase().replaceAll("\\.", " ").replaceAll("\\;", " ").replaceAll("\\,", " ").replaceAll("\\:", " ").replaceAll("\\-\\s", "").replaceAll("\\\"", "");
+      return s.toLowerCase().replaceAll("\\.", " ");//.replaceAll("\\;", " ").replaceAll("\\,", " ").replaceAll("\\:", " ").replaceAll("\\-\\s", "").replaceAll("\\\"", "");
     }
   }
 }
\ No newline at end of file

Modified: labs/yay/trunk/core/src/test/java/org/apache/yay/MultiLayerNetworkTest.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/MultiLayerNetworkTest.java?rev=1732287&r1=1732286&r2=1732287&view=diff
==============================================================================
--- labs/yay/trunk/core/src/test/java/org/apache/yay/MultiLayerNetworkTest.java (original)
+++ labs/yay/trunk/core/src/test/java/org/apache/yay/MultiLayerNetworkTest.java Thu Feb 25 11:17:56 2016
@@ -149,7 +149,6 @@ public class MultiLayerNetworkTest {
 
     MultiLayerNetwork nor = new MultiLayerNetwork(configuration, norRealMatrixSet);
 
-
     assertEquals(0L, Math.round(nor.predictOutput(new double[]{1d, 0d})[0]));
     assertEquals(0L, Math.round(nor.predictOutput(new double[]{0d, 1d})[0]));
     assertEquals(1L, Math.round(nor.predictOutput(new double[]{0d, 0d})[0]));

Modified: labs/yay/trunk/core/src/test/java/org/apache/yay/SkipGramNetworkTest.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/SkipGramNetworkTest.java?rev=1732287&r1=1732286&r2=1732287&view=diff
==============================================================================
--- labs/yay/trunk/core/src/test/java/org/apache/yay/SkipGramNetworkTest.java (original)
+++ labs/yay/trunk/core/src/test/java/org/apache/yay/SkipGramNetworkTest.java Thu Feb 25 11:17:56 2016
@@ -18,9 +18,14 @@
  */
 package org.apache.yay;
 
+import org.apache.commons.math3.linear.MatrixUtils;
 import org.apache.commons.math3.linear.RealMatrix;
+import org.apache.commons.math3.ml.distance.CanberraDistance;
+import org.apache.commons.math3.ml.distance.ChebyshevDistance;
 import org.apache.commons.math3.ml.distance.DistanceMeasure;
 import org.apache.commons.math3.ml.distance.EuclideanDistance;
+import org.apache.commons.math3.ml.distance.ManhattanDistance;
+import org.apache.commons.math3.util.FastMath;
 import org.junit.Test;
 
 import java.io.BufferedWriter;
@@ -29,8 +34,10 @@ import java.io.FileWriter;
 import java.io.IOException;
 import java.nio.file.Path;
 import java.nio.file.Paths;
+import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collection;
+import java.util.Date;
 import java.util.LinkedList;
 import java.util.List;
 
@@ -42,64 +49,76 @@ public class SkipGramNetworkTest {
   @Test
   public void testWordVectorsLearningOnAbstracts() throws Exception {
    Path path = Paths.get(getClass().getResource("/word2vec/abstracts.txt").getFile());
-    SkipGramNetwork network = SkipGramNetwork.newModel().withWindow(4).fromTextAt(path).withDimension(10).build();
+    int window = 3;
+    SkipGramNetwork network = SkipGramNetwork.newModel().withWindow(window).fromTextAt(path).withDimension(10).build();
     RealMatrix wv = network.getWeights()[0];
     List<String> vocabulary = network.getVocabulary();
     serialize(vocabulary, wv);
-    measure(vocabulary, wv);
+    evaluate(network, window);
   }
 
   @Test
   public void testWordVectorsLearningOnSentences() throws Exception {
    Path path = Paths.get(getClass().getResource("/word2vec/sentences.txt").getFile());
-    SkipGramNetwork network = SkipGramNetwork.newModel().withWindow(4).fromTextAt(path).withDimension(10).build();
+    int window = 3;
+    SkipGramNetwork network = SkipGramNetwork.newModel().withWindow(window).fromTextAt(path).withDimension(10).build();
     RealMatrix wv = network.getWeights()[0];
     List<String> vocabulary = network.getVocabulary();
     serialize(vocabulary, wv);
-    measure(vocabulary, wv);
+    evaluate(network, window);
   }
 
   @Test
   public void testWordVectorsLearningOnTestData() throws Exception {
    Path path = Paths.get(getClass().getResource("/word2vec/test.txt").getFile());
-    SkipGramNetwork network = SkipGramNetwork.newModel().withWindow(4).fromTextAt(path).withDimension(10).build();
+    int window = 3;
+    SkipGramNetwork network = SkipGramNetwork.newModel().withWindow(window).fromTextAt(path).withDimension(10).build();
+    evaluate(network, window);
+    network.learnWeights(network.getSamples());
+    evaluate(network, window);
     RealMatrix wv = network.getWeights()[0];
     List<String> vocabulary = network.getVocabulary();
     serialize(vocabulary, wv);
-    measure(vocabulary, wv);
   }
 
   private void measure(List<String> vocabulary, RealMatrix wordVectors) {
     System.out.println("measuring similarities");
     Collection<DistanceMeasure> measures = new LinkedList<>();
     measures.add(new EuclideanDistance());
-//    measures.add(new DistanceMeasure() {
-//      @Override
-//      public double compute(double[] a, double[] b) {
-//        double dp = 0.0;
-//        double na = 0.0;
-//        double nb = 0.0;
-//        for (int i = 0; i < a.length; i++) {
-//          dp += a[i] * b[i];
-//          na += Math.pow(a[i], 2);
-//          nb += Math.pow(b[i], 2);
-//        }
-//        double cosineSimilarity = dp / (Math.sqrt(na) * Math.sqrt(nb));
-//        return 1 / cosineSimilarity;
-//      }
-//
-//      @Override
-//      public String toString() {
-//        return "inverse cosine similarity distance measure";
-//      }
-//    });
-//    measures.add((DistanceMeasure) (a, b) -> {
-//      double da = FastMath.sqrt(MatrixUtils.createRealVector(a).dotProduct(MatrixUtils.createRealVector(a)));
-//      double db = FastMath.sqrt(MatrixUtils.createRealVector(b).dotProduct(MatrixUtils.createRealVector(b)));
-//      return Math.abs(db - da);
-//    });
+    measures.add(new CanberraDistance());
+    measures.add(new ChebyshevDistance());
+    measures.add(new ManhattanDistance());
+    measures.add(new DistanceMeasure() {
+      @Override
+      public double compute(double[] a, double[] b) {
+        double dp = 0.0;
+        double na = 0.0;
+        double nb = 0.0;
+        for (int i = 0; i < a.length; i++) {
+          dp += a[i] * b[i];
+          na += Math.pow(a[i], 2);
+          nb += Math.pow(b[i], 2);
+        }
+        double cosineSimilarity = dp / (Math.sqrt(na) * Math.sqrt(nb));
+        return 1 / cosineSimilarity;
+      }
+
+      @Override
+      public String toString() {
+        return "inverse cosine similarity distance measure";
+      }
+    });
+    measures.add((DistanceMeasure) (a, b) -> {
+      double da = FastMath.sqrt(MatrixUtils.createRealVector(a).dotProduct(MatrixUtils.createRealVector(a)));
+      double db = FastMath.sqrt(MatrixUtils.createRealVector(b).dotProduct(MatrixUtils.createRealVector(b)));
+      return Math.abs(db - da);
+    });
     for (DistanceMeasure distanceMeasure : measures) {
-      System.out.println("computing similarity using " + distanceMeasure);
+      System.out.println("*********************************************");
+      System.out.println("*********************************************");
+      System.out.println("*** similarity by " + distanceMeasure + "***");
+      System.out.println("*********************************************");
+      System.out.println("*********************************************");
       computeSimilarities(vocabulary, wordVectors, distanceMeasure);
     }
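
A note on the re-enabled "inverse cosine similarity" measure above: 1 / cos(a, b) blows up near orthogonal vectors and goes negative for opposed ones, so the more conventional cosine distance 1 - cos(a, b), bounded in [0, 2], may be a safer drop-in (sketch):

    import org.apache.commons.math3.ml.distance.DistanceMeasure;

    public class CosineDistanceSketch implements DistanceMeasure {
      @Override
      public double compute(double[] a, double[] b) {
        double dp = 0, na = 0, nb = 0;
        for (int i = 0; i < a.length; i++) {
          dp += a[i] * b[i];
          na += a[i] * a[i];
          nb += b[i] * b[i];
        }
        // 1 - cosine similarity: 0 for identical directions, 2 for opposite
        return 1 - dp / (Math.sqrt(na) * Math.sqrt(nb));
      }
    }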
 
@@ -107,7 +126,7 @@ public class SkipGramNetworkTest {
 
  private void serialize(List<String> vocabulary, RealMatrix wordVectors) throws IOException {
     System.out.println("serializing word vectors");
-    BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(new File("target/sg-vectors.csv")));
+    BufferedWriter bufferedWriter = new BufferedWriter(new FileWriter(new File("target/sg-vectors-" + new Date().toString() + ".csv")));
     for (int i = 1; i < wordVectors.getColumnDimension(); i++) {
       double[] a = wordVectors.getColumnVector(i).toArray();
       String csq = Arrays.toString(Arrays.copyOfRange(a, 1, a.length));
@@ -163,12 +182,58 @@ public class SkipGramNetworkTest {
       }
       if (i > 0 && j0 > 0 && j1 > 0 && j2 > 0) {
         System.out.println(vocabulary.get(i - 1) + " -> "
-                + vocabulary.get(j0 - 1) + ", "
-                + vocabulary.get(j1 - 1) + ", "
-                + vocabulary.get(j2 - 1));
+                        + vocabulary.get(j0 - 1)
+//                + ", "
+//                + vocabulary.get(j1 - 1)
+//                + ", "
+//                + vocabulary.get(j2 - 1)
+        );
       } else {
        System.err.println("no similarity for '" + vocabulary.get(i) + "' with " + distanceMeasure);
       }
     }
   }
+
+  private void evaluate(SkipGramNetwork network, int window) throws Exception {
+    double cc = 0;
+    double wc = 0;
+    for (Sample sample : network.getSamples()) {
+      Collection<Integer> exps = new ArrayList<>(window - 1);
+      Collection<Integer> acts = new ArrayList<>(window - 1);
+      double[] inputs = sample.getInputs();
+      double[] actualOutputs = network.predictOutput(inputs);
+      double[] expectedOutputs = sample.getOutputs();
+      int j = 0;
+      for (int i = 0; i < window - 1; i++) {
+        int actualMax = getMaxIndex(actualOutputs, j, j + inputs.length - 1);
+        int expectedMax = getMaxIndex(expectedOutputs, j, j + inputs.length - 1);
+        exps.add(expectedMax);
+        acts.add(actualMax);
+        j += i + inputs.length - 2;
+      }
+      boolean c = true;
+      for (Integer a : acts) {
+        c &= exps.contains(a);
+      }
+      if (c) {
+        cc++;
+      } else {
+        wc++;
+      }
+    }
+    System.out.println("accuracy: " + (cc / (wc + cc)));
+  }
+
+  private int getMaxIndex(double[] array, int start, int end) {
+    double largest = array[start];
+    int index = start;
+    for (int i = start + 1; i < end; i++) {
+      if (array[i] >= largest) {
+        largest = array[i];
+        index = i;
+      }
+    }
+    return index;
+  }
+
 }

Modified: labs/yay/trunk/core/src/test/resources/word2vec/test.txt
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/resources/word2vec/test.txt?rev=1732287&r1=1732286&r2=1732287&view=diff
==============================================================================
--- labs/yay/trunk/core/src/test/resources/word2vec/test.txt (original)
+++ labs/yay/trunk/core/src/test/resources/word2vec/test.txt Thu Feb 25 11:17:56 2016
@@ -1,3 +1,8 @@
 the dog saw a cat
 the dog chased the cat
-the cat climbed a tree
\ No newline at end of file
+the cat climbed a tree
+a dog is similar to a cat
+dogs eat cats
+cats eat rats
+rats eat everything
+a rat saw something
\ No newline at end of file

Modified: labs/yay/trunk/pom.xml
URL: http://svn.apache.org/viewvc/labs/yay/trunk/pom.xml?rev=1732287&r1=1732286&r2=1732287&view=diff
==============================================================================
--- labs/yay/trunk/pom.xml (original)
+++ labs/yay/trunk/pom.xml Thu Feb 25 11:17:56 2016
@@ -80,7 +80,7 @@
         <dependency>
             <groupId>org.apache.commons</groupId>
             <artifactId>commons-math3</artifactId>
-            <version>3.5</version>
+          <version>3.6</version>
         </dependency>
     </dependencies>
   </dependencyManagement>


