Author: tommaso
Date: Tue Oct  6 12:01:51 2015
New Revision: 1707019

URL: http://svn.apache.org/viewvc?rev=1707019&view=rev
Log:
Draft word2vec test for a skip-gram model (SGM) network

Added:
    labs/yay/trunk/core/src/test/java/org/apache/yay/core/Word2VecTest.java
    labs/yay/trunk/core/src/test/resources/word2vec/
    labs/yay/trunk/core/src/test/resources/word2vec/sentences.txt
Modified:
    
labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java
    
labs/yay/trunk/core/src/main/java/org/apache/yay/core/LogisticRegressionCostFunction.java
    
labs/yay/trunk/core/src/test/java/org/apache/yay/core/BackPropagationLearningStrategyTest.java

Modified: 
labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java
URL: 
http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java?rev=1707019&r1=1707018&r2=1707019&view=diff
==============================================================================
--- 
labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java
 (original)
+++ 
labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java
 Tue Oct  6 12:01:51 2015
@@ -39,7 +39,7 @@ import org.apache.yay.WeightLearningExce
 public class BackPropagationLearningStrategy implements 
LearningStrategy<Double, Double> {
 
   public static final double DEFAULT_THRESHOLD = 0.05;
-  private static final int MAX_ITERATIONS = 100000;
+  public static final int MAX_ITERATIONS = 100000;
   public static final double DEFAULT_ALPHA = 0.000003;
 
   private final PredictionStrategy<Double, Double> predictionStrategy;
@@ -48,31 +48,33 @@ public class BackPropagationLearningStra
   private final double alpha;
   private final double threshold;
   private final int batch;
-
+  private final int maxIterations;
 
   public BackPropagationLearningStrategy(double alpha, double threshold, 
PredictionStrategy<Double, Double> predictionStrategy,
                                          CostFunction<RealMatrix, Double, 
Double> costFunction) {
-    this(alpha, 1, threshold, predictionStrategy, costFunction);
+    this(alpha, 1, threshold, predictionStrategy, costFunction, 
MAX_ITERATIONS);
   }
 
   public BackPropagationLearningStrategy(double alpha, int batch, double 
threshold, PredictionStrategy<Double, Double> predictionStrategy,
-                                         CostFunction<RealMatrix, Double, 
Double> costFunction) {
+                                         CostFunction<RealMatrix, Double, 
Double> costFunction, int maxIterations) {
     this.predictionStrategy = predictionStrategy;
     this.costFunction = costFunction;
     this.alpha = alpha;
     this.threshold = threshold;
     this.batch = batch;
     this.derivativeUpdateFunction = new 
DefaultDerivativeUpdateFunction(predictionStrategy);
+    this.maxIterations = maxIterations;
   }
 
   public BackPropagationLearningStrategy() {
     // commonly used defaults
-    this.predictionStrategy = new FeedForwardStrategy(new TanhFunction());
+    this.predictionStrategy = new FeedForwardStrategy(new SigmoidFunction());
     this.costFunction = new LogisticRegressionCostFunction();
     this.alpha = DEFAULT_ALPHA;
     this.threshold = DEFAULT_THRESHOLD;
     this.batch = 1;
     this.derivativeUpdateFunction = new 
DefaultDerivativeUpdateFunction(predictionStrategy);
+    this.maxIterations = MAX_ITERATIONS;
   }
 
   @Override
@@ -106,7 +108,7 @@ public class BackPropagationLearningStra
 
         if (newCost > cost && batch == -1) {
           throw new RuntimeException("failed to converge at iteration " + 
iterations + " with alpha " + alpha + " : cost going from " + cost + " to " + 
newCost);
-        } else if (iterations > 1 && (cost == newCost || newCost < threshold 
|| iterations > MAX_ITERATIONS)) {
+        } else if (iterations > 1 && (cost == newCost || newCost < threshold 
|| iterations > maxIterations)) {
           System.out.println("successfully converged after " + (iterations - 
1) + " iterations (alpha:" + alpha + ",threshold:" + threshold + ") with cost " 
+ newCost + " and parameters " + Arrays.toString(hypothesis.getParameters()));
           break;
         }

Modified: 
labs/yay/trunk/core/src/main/java/org/apache/yay/core/LogisticRegressionCostFunction.java
URL: 
http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/LogisticRegressionCostFunction.java?rev=1707019&r1=1707018&r2=1707019&view=diff
==============================================================================
--- 
labs/yay/trunk/core/src/main/java/org/apache/yay/core/LogisticRegressionCostFunction.java
 (original)
+++ 
labs/yay/trunk/core/src/main/java/org/apache/yay/core/LogisticRegressionCostFunction.java
 Tue Oct  6 12:01:51 2015
@@ -77,8 +77,10 @@ public class LogisticRegressionCostFunct
       Double[] predictedOutput = hypothesis.predict(input);
       Double[] sampleOutput = input.getOutput();
       for (int i = 0; i < predictedOutput.length; i++) {
-        res += sampleOutput[i] * Math.log(predictedOutput[i]) + (1d - 
sampleOutput[i])
-                * Math.log(1d - predictedOutput[i]);
+        Double so = sampleOutput[i];
+        Double po = predictedOutput[i];
+        res += so * Math.log(po) + (1d - so)
+                * Math.log(1d - po);
       }
     }
     return (-1d / trainingExamples.length) * res;

Modified: 
labs/yay/trunk/core/src/test/java/org/apache/yay/core/BackPropagationLearningStrategyTest.java
URL: 
http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/core/BackPropagationLearningStrategyTest.java?rev=1707019&r1=1707018&r2=1707019&view=diff
==============================================================================
--- 
labs/yay/trunk/core/src/test/java/org/apache/yay/core/BackPropagationLearningStrategyTest.java
 (original)
+++ 
labs/yay/trunk/core/src/test/java/org/apache/yay/core/BackPropagationLearningStrategyTest.java
 Tue Oct  6 12:01:51 2015
@@ -73,7 +73,8 @@ public class BackPropagationLearningStra
       assertFalse("weights have not been changed", 
learntWeights[i].equals(initialWeights[i]));
     }
 
-    backPropagationLearningStrategy = new 
BackPropagationLearningStrategy(alpha, -1, threshold, predictionStrategy, 
costFunction);
+    backPropagationLearningStrategy = new 
BackPropagationLearningStrategy(alpha, -1, threshold, predictionStrategy,
+            costFunction, BackPropagationLearningStrategy.MAX_ITERATIONS);
     learntWeights = 
backPropagationLearningStrategy.learnWeights(initialWeights, trainingSet);
     assertNotNull(learntWeights);
 
@@ -147,7 +148,8 @@ public class BackPropagationLearningStra
     assertFalse(learntWeights[2].equals(initialWeights[2]));
 
     backPropagationLearningStrategy = new 
BackPropagationLearningStrategy(BackPropagationLearningStrategy.DEFAULT_ALPHA, 
-1,
-            BackPropagationLearningStrategy.DEFAULT_THRESHOLD, new 
FeedForwardStrategy(new SigmoidFunction()), new 
LogisticRegressionCostFunction(0.5d));
+            BackPropagationLearningStrategy.DEFAULT_THRESHOLD, new 
FeedForwardStrategy(new SigmoidFunction()),
+            new LogisticRegressionCostFunction(0.5d), 
BackPropagationLearningStrategy.MAX_ITERATIONS);
 
     learntWeights = 
backPropagationLearningStrategy.learnWeights(initialWeights, trainingSet);
     assertNotNull(learntWeights);

Added: labs/yay/trunk/core/src/test/java/org/apache/yay/core/Word2VecTest.java
URL: 
http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/core/Word2VecTest.java?rev=1707019&view=auto
==============================================================================
--- labs/yay/trunk/core/src/test/java/org/apache/yay/core/Word2VecTest.java 
(added)
+++ labs/yay/trunk/core/src/test/java/org/apache/yay/core/Word2VecTest.java Tue 
Oct  6 12:01:51 2015
@@ -0,0 +1,286 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *  http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.yay.core;
+
+import java.io.BufferedReader;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.InputStreamReader;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Random;
+
+import org.apache.commons.math3.linear.Array2DRowRealMatrix;
+import org.apache.commons.math3.linear.RealMatrix;
+import org.apache.yay.Feature;
+import org.apache.yay.Input;
+import org.apache.yay.NeuralNetwork;
+import org.apache.yay.TrainingExample;
+import org.apache.yay.TrainingSet;
+import org.apache.yay.core.utils.ConversionUtils;
+import org.apache.yay.core.utils.ExamplesFactory;
+import org.junit.Test;
+
+import static org.junit.Assert.*;
+
+/**
+ * Integration test for using Yay to implement word vectors algorithms.
+ */
+public class Word2VecTest {
+
+  @Test
+  public void testSGM() throws Exception {
+    Collection<String> sentences = getSentences();
+    assertFalse(sentences.isEmpty());
+    List<String> vocabulary = getVocabulary(sentences);
+    assertFalse(vocabulary.isEmpty());
+    Collections.sort(vocabulary);
+    Collection<String> fragments = getFragments(sentences, 4);
+    assertFalse(fragments.isEmpty());
+
+    TrainingSet<Double, Double> trainingSet = createTrainingSet(vocabulary, 
fragments);
+
+//    int n = new Random().nextInt(20);
+
+    TrainingExample<Double, Double> next = trainingSet.iterator().next();
+    int inputSize = next.getFeatures().size();
+    int outputSize = next.getOutput().length;
+    RealMatrix[] randomWeights = createRandomWeights(inputSize, inputSize, 
outputSize);
+
+    FeedForwardStrategy predictionStrategy = new FeedForwardStrategy(new 
IdentityActivationFunction<Double>());
+    BackPropagationLearningStrategy learningStrategy = new 
BackPropagationLearningStrategy(BackPropagationLearningStrategy.
+            DEFAULT_ALPHA, -1, 
BackPropagationLearningStrategy.DEFAULT_THRESHOLD, predictionStrategy, new 
LMSCostFunction(),
+            20);
+    NeuralNetwork neuralNetwork = NeuralNetworkFactory.create(randomWeights, 
learningStrategy, predictionStrategy);
+
+    neuralNetwork.learn(trainingSet);
+
+    String word = "paper";
+//    final Double[] doubles = 
ConversionUtils.toValuesCollection(next.getFeatures()).toArray(new 
Double[next.getFeatures().size()]);
+    final Double[] doubles = hotEncode(word, vocabulary);
+//    String word = hotDecode(doubles, vocabulary);
+
+//    TrainingExample<Double, Double> input = 
ExamplesFactory.createDoubleArrayTrainingExample(new Double[outputSize], 
doubles);
+    Input<Double> input = new TrainingExample<Double, Double>() {
+      @Override
+      public ArrayList<Feature<Double>> getFeatures() {
+        ArrayList<Feature<Double>> features = new ArrayList<Feature<Double>>();
+        for (Double d : doubles) {
+          Feature<Double> f = new Feature<Double>();
+          f.setValue(d);
+          features.add(f);
+        }
+        return features;
+      }
+
+      @Override
+      public Double[] getOutput() {
+        return new Double[0];
+      }
+    };
+    Double[] predict = neuralNetwork.predict(input);
+    assertNotNull(predict);
+
+    System.out.println(Arrays.toString(predict));
+
+    Double[] wordVec1 = Arrays.copyOfRange(predict, 0, vocabulary.size());
+    assertNotNull(wordVec1);
+    Double[] wordVec2 = Arrays.copyOfRange(predict, vocabulary.size(), 2 * 
vocabulary.size());
+    assertNotNull(wordVec2);
+    Double[] wordVec3 = Arrays.copyOfRange(predict, 2 * vocabulary.size(), 3 * 
vocabulary.size());
+    assertNotNull(wordVec3);
+
+    String word1 = hotDecode(wordVec1, vocabulary);
+    assertNotNull(word1);
+    assertTrue(vocabulary.contains(word1));
+    String word2 = hotDecode(wordVec2, vocabulary);
+    assertNotNull(word2);
+    assertTrue(vocabulary.contains(word2));
+    String word3 = hotDecode(wordVec3, vocabulary);
+    assertNotNull(word3);
+    assertTrue(vocabulary.contains(word3));
+
+    System.out.println(word + " -> " + word1 + " " + word2 + " " + word3);
+  }
+
+  private String hotDecode(Double[] doubles, List<String> vocabulary) {
+    double max = -Double.MAX_VALUE;
+    int index = -1;
+    for (int i = 0; i < doubles.length; i++) {
+      Double aDouble = doubles[i];
+      if (aDouble > max) {
+        max = aDouble;
+        index = i;
+      }
+    }
+    return vocabulary.get(index);
+  }
+
+
+  private TrainingSet<Double, Double> createTrainingSet(List<String> 
vocabulary, Collection<String> fragments) {
+    Collection<TrainingExample<Double, Double>> samples = new 
LinkedList<TrainingExample<Double, Double>>();
+    for (String fragment : fragments) {
+      String[] tokens = fragment.split(" ");
+      String inputWord = null;
+      for (int i = 0; i < tokens.length; i++) {
+        List<String> outputWords = new LinkedList<String>();
+        for (int j = 0; j < tokens.length; j++) {
+          String token = tokens[i];
+          if (i == j) {
+            inputWord = token;
+          } else {
+            outputWords.add(token);
+          }
+        }
+
+        final Double[] input = hotEncode(inputWord, vocabulary);
+        final Double[] outputs = new Double[outputWords.size() * 
vocabulary.size()];
+        for (int k = 0; k < outputWords.size(); k++) {
+          Double[] doubles = hotEncode(outputWords.get(k), vocabulary);
+          for (int z = 0; z < doubles.length; z++) {
+            outputs[(k * doubles.length) + z] = doubles[z];
+          }
+        }
+        samples.add(new TrainingExample<Double, Double>() {
+          @Override
+          public Double[] getOutput() {
+            return outputs;
+          }
+
+          @Override
+          public ArrayList<Feature<Double>> getFeatures() {
+            ArrayList<Feature<Double>> features = new 
ArrayList<Feature<Double>>();
+            for (Double d : input) {
+              Feature<Double> e = new Feature<Double>();
+              e.setValue(d);
+              features.add(e);
+            }
+            return features;
+          }
+        });
+      }
+    }
+    return new TrainingSet<Double, Double>(samples);
+  }
+
+  private Double[] hotEncode(String word, List<String> vocabulary) {
+    Double[] vector = new Double[vocabulary.size()];
+    int index = Collections.binarySearch(vocabulary, word);
+    Arrays.fill(vector, 0d);
+    vector[index] = 1d;
+    return vector;
+  }
+
+  private List<String> getVocabulary(Collection<String> sentences) {
+    List<String> vocabulary = new LinkedList<String>();
+    for (String sentence : sentences) {
+      for (String token : sentence.split(" ")) {
+        if (!vocabulary.contains(token)) {
+          vocabulary.add(token);
+        }
+      }
+    }
+    return vocabulary;
+  }
+
+  private Collection<String> getFragments(Collection<String> vocabulary, int 
w) {
+    Collection<String> fragments = new LinkedList<String>();
+    for (String sentence : vocabulary) {
+      while (sentence.length() > 0) {
+        int idx = 0;
+        for (int i = 0; i < w; i++) {
+          idx = sentence.indexOf(' ', idx + 1);
+        }
+        if (idx > 0) {
+          String fragment = sentence.substring(0, idx);
+          if (fragment.split(" ").length == 4) {
+            fragments.add(fragment);
+            sentence = sentence.substring(sentence.indexOf(' ') + 1);
+          }
+        } else {
+          if (sentence.split(" ").length == 4) {
+            fragments.add(sentence);
+            sentence = "";
+          }
+        }
+      }
+    }
+    return fragments;
+  }
+
+  private Collection<String> getSentences() throws IOException {
+    InputStream resourceAsStream = 
getClass().getResourceAsStream("/word2vec/sentences.txt");
+    BufferedReader bufferedReader = new BufferedReader(new 
InputStreamReader(resourceAsStream));
+    Collection<String> sentences = new LinkedList<String>();
+    String line;
+    while ((line = bufferedReader.readLine()) != null) {
+      sentences.add(line);
+    }
+    return sentences;
+  }
+
+  private RealMatrix[] createRandomWeights(int inputSize, int hiddenSize, int 
outputSize) {
+    Random r = new Random();
+    int weightsCount = 2;
+
+    RealMatrix[] initialWeights = new RealMatrix[weightsCount];
+    for (int i = 0; i < weightsCount; i++) {
+      int rows = inputSize;
+      int cols;
+      if (i == 0) {
+        cols = hiddenSize;
+      } else {
+        cols = initialWeights[i - 1].getRowDimension();
+        if (i == weightsCount - 1) {
+          rows = outputSize;
+        }
+      }
+      double[][] d = new double[rows][cols];
+      for (int c = 0; c < cols; c++) {
+        if (i == weightsCount - 1) {
+          if (c == 0) {
+            d[0][c] = 1d;
+          } else {
+            d[0][c] = r.nextInt(100) / 101d;
+          }
+        } else {
+          d[0][c] = 0;
+        }
+      }
+
+      for (int k = 1; k < rows; k++) {
+        for (int j = 0; j < cols; j++) {
+          double val;
+          if (j == 0) {
+            val = 1d;
+          } else {
+            val = r.nextInt(100) / 101d;
+          }
+          d[k][j] = val;
+        }
+      }
+      initialWeights[i] = new Array2DRowRealMatrix(d);
+    }
+    return initialWeights;
+  }
+}

Added: labs/yay/trunk/core/src/test/resources/word2vec/sentences.txt
URL: 
http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/resources/word2vec/sentences.txt?rev=1707019&view=auto
==============================================================================
--- labs/yay/trunk/core/src/test/resources/word2vec/sentences.txt (added)
+++ labs/yay/trunk/core/src/test/resources/word2vec/sentences.txt Tue Oct  6 
12:01:51 2015
@@ -0,0 +1,8 @@
+The word2vec software of Tomas Mikolov and colleagues1 has gained a lot of 
traction lately and provides state-of-the-art word embeddings
+The learning models behind the software are described in two research papers.
+We found the description of the models in these papers to be somewhat cryptic 
and hard to follow
+While the motivations and presentation may be obvious to the neural-networks 
language-modeling crowd we had to struggle quite a bit to figure out the 
rationale behind the equations
+This note is an attempt to explain the negative sampling equation in 
“Distributed Representations of Words and Phrases and their 
Compositionality” by Tomas Mikolov, Ilya Sutskever, Kai Chen, Greg Corrado 
and Jeffrey Dean
+The departure point of the paper is the skip-gram model
+In this model we are given a corpus of words w and their contexts c
+We consider the conditional probabilities p(c|w) and given a corpus Text, the 
goal is to set the parameters θ of p(c|w; θ) so as to maximize the corpus 
probability
\ No newline at end of file



---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@labs.apache.org
For additional commands, e-mail: commits-h...@labs.apache.org

Reply via email to