Author: tommaso
Date: Wed Oct  5 16:12:24 2016
New Revision: 1763464

URL: http://svn.apache.org/viewvc?rev=1763464&view=rev
Log:
minor fixes, java raw char-rnn model inspired by karpathy

Added:
    labs/yay/trunk/core/src/main/java/org/apache/yay/RNN.java   (with props)
    labs/yay/trunk/core/src/test/java/org/apache/yay/RNNTest.java   (with props)
Modified:
    labs/yay/trunk/core/pom.xml
    labs/yay/trunk/core/src/main/java/org/apache/yay/MultiLayerNetwork.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/SkipGramNetwork.java
    labs/yay/trunk/core/src/test/java/org/apache/yay/MultiLayerNetworkTest.java
    labs/yay/trunk/core/src/test/java/org/apache/yay/SkipGramNetworkTest.java

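The new RNN class is exercised end-to-end by the RNNTest added below, which simply feeds it raw text. A minimal sketch of that usage, assuming the same commons-io helper and classpath resource the test uses (the wrapper class name is illustrative only; learn() runs an endless training loop, periodically printing the smoothed loss and a sampled snippet of generated text):

    import java.io.InputStream;
    import org.apache.commons.io.IOUtils;
    import org.apache.yay.RNN;

    public class CharRnnDemo {
      public static void main(String[] args) throws Exception {
        // read a plain-text corpus from the classpath (same resource RNNTest uses)
        InputStream in = CharRnnDemo.class.getResourceAsStream("/word2vec/abstracts.txt");
        String text = IOUtils.toString(in);
        // train a character-level model on the raw characters; this call never returns
        new RNN().learn(text);
      }
    }
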
Modified: labs/yay/trunk/core/pom.xml
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/pom.xml?rev=1763464&r1=1763463&r2=1763464&view=diff
==============================================================================
--- labs/yay/trunk/core/pom.xml (original)
+++ labs/yay/trunk/core/pom.xml Wed Oct  5 16:12:24 2016
@@ -25,6 +25,9 @@
         <version>0.2-SNAPSHOT</version>
         <relativePath>../</relativePath>
     </parent>
+    <properties>
+        <nd4j.version>0.6.0</nd4j.version>
+    </properties>
     <name>Yay core</name>
     <dependencies>
         <dependency>
@@ -52,6 +55,12 @@
             <artifactId>guava</artifactId>
             <version>18.0</version>
         </dependency>
+        <dependency>
+            <groupId>org.nd4j</groupId>
+            <artifactId>nd4j-native-platform</artifactId>
+            <version>${nd4j.version}</version>
+        </dependency>
+
     </dependencies>
     <build>
         <plugins>

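The nd4j-native-platform dependency added above supplies the INDArray type and the few dense-matrix operations the new RNN relies on (Nd4j.randn/zeros, mmul, and the element-wise Transforms helpers). A small smoke-test sketch, assuming ND4J 0.6.0 as pinned in the property above; the matrix sizes and the index used are illustrative only:

    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;
    import org.nd4j.linalg.ops.transforms.Transforms;

    public class Nd4jSmoke {
      public static void main(String[] args) {
        INDArray wxh = Nd4j.randn(40, 65).mul(0.001);       // small random input-to-hidden weights
        INDArray bh = Nd4j.zeros(40, 1);                     // hidden bias
        INDArray x = Nd4j.zeros(65, 1);                      // one-hot encoded input character
        x.putScalar(7, 1.0);
        INDArray h = Transforms.tanh(wxh.mmul(x).add(bh));   // one hidden-state update, as in RNN.java
        System.out.println(h);
      }
    }
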
Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/MultiLayerNetwork.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/MultiLayerNetwork.java?rev=1763464&r1=1763463&r2=1763464&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/MultiLayerNetwork.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/MultiLayerNetwork.java Wed Oct  5 16:12:24 2016
@@ -110,9 +110,7 @@ public class MultiLayerNetwork {
     while (true) {
       if (iterations % (1 + (configuration.maxIterations / 100)) == 0) {
         long time = (System.currentTimeMillis() - start) / 1000;
-//        if (time > 60) {
           System.out.println("cost is " + cost + " after " + iterations + " 
iterations in " + (time / 60) + " minutes (" + ((double) iterations / time) + " 
ips)");
-//        }
       }
       // current training example
       Sample sample = samples[iterations % samples.length];

Added: labs/yay/trunk/core/src/main/java/org/apache/yay/RNN.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/RNN.java?rev=1763464&view=auto
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/RNN.java (added)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/RNN.java Wed Oct  5 16:12:24 2016
@@ -0,0 +1,277 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *  http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.yay;
+
+import org.apache.commons.math3.distribution.EnumeratedDistribution;
+import org.apache.commons.math3.util.Pair;
+import org.nd4j.linalg.api.iter.NdIndexIterator;
+import org.nd4j.linalg.api.ndarray.INDArray;
+import org.nd4j.linalg.api.ops.impl.transforms.SetRange;
+import org.nd4j.linalg.factory.Nd4j;
+import org.nd4j.linalg.ops.transforms.Transforms;
+
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+/**
+ * A min char-level vanilla RNN model, based on Andrej Karpathy's python code.
+ * See also:
+ *
+ * @see <a href="http://karpathy.github.io/2015/05/21/rnn-effectiveness">The Unreasonable Effectiveness of Recurrent Neural Networks</a>
+ * @see <a href="https://gist.github.com/karpathy/d4dee566867f8291f086">Minimal character-level language model with a Vanilla Recurrent Neural Network, in Python/numpy</a>
+ */
+public class RNN {
+
+  public void learn(String text) {
+
+    char[] textChars = text.toCharArray();
+    List<Character> data = new LinkedList<>();
+    for (char c : textChars) {
+      data.add(c);
+    }
+    Set<Character> chars = new HashSet<>(data);
+    int vocabSize = chars.size();
+    System.out.printf("data has %d characters, %d unique.", data.size(), 
vocabSize);
+    Map<Character, Integer> charToIx = new HashMap<>();
+    Map<Integer, Character> ixToChar = new HashMap<>();
+    int i = 0;
+    for (Character c : chars) {
+      charToIx.put(c, i);
+      ixToChar.put(i, c);
+      i++;
+    }
+
+    // hyperparameters
+    int hiddenSize = 40; // size of hidden layer of neurons
+    int seqLength = 10; // no. of steps to unroll the RNN for
+    float learningRate = 1e-2f;
+
+    // model parameters
+    INDArray wxh = Nd4j.randn(hiddenSize, vocabSize).mul(0.001); // input to hidden
+    INDArray whh = Nd4j.randn(hiddenSize, hiddenSize).mul(0.001); // hidden to hidden
+    INDArray why = Nd4j.randn(vocabSize, hiddenSize).mul(0.001); // hidden to output
+    INDArray bh = Nd4j.zeros(hiddenSize, 1); // hidden bias
+    INDArray by = Nd4j.zeros(vocabSize, 1); // output bias
+
+    int n = 0;
+    int p = 0;
+
+    // memory variables for Adagrad
+    INDArray mWxh = Nd4j.zerosLike(wxh);
+    INDArray mWhh = Nd4j.zerosLike(whh);
+    INDArray mWhy = Nd4j.zerosLike(why);
+
+    INDArray mbh = Nd4j.zerosLike(bh);
+    INDArray mby = Nd4j.zerosLike(by);
+
+    // loss at iteration 0
+    double smoothLoss = -Math.log(1.0 / vocabSize) * seqLength;
+
+    INDArray hPrev = null;
+    while (true) {
+      // prepare inputs (we're sweeping from left to right in steps seqLength long)
+      if (p + seqLength + 1 >= data.size() || n == 0) {
+        hPrev = Nd4j.zeros(hiddenSize, 1); // reset RNN memory
+        p = 0; // go from start of data
+      }
+
+      INDArray inputs = Nd4j.create(seqLength);
+      int c = 0;
+      for (Character ch : data.subList(p, p + seqLength)) {
+        Integer ix = charToIx.get(ch);
+        inputs.putScalar(c, ix);
+        c++;
+      }
+      INDArray targets = Nd4j.create(seqLength);
+      c = 0;
+      for (Character ch : data.subList(p + 1, p + seqLength + 1)) {
+        Integer ix = charToIx.get(ch);
+        targets.putScalar(c, ix);
+        c++;
+      }
+
+      // sample from the model now and then
+      if (n % 1000 == 0) {
+        sample(vocabSize, ixToChar, wxh, whh, why, bh, by, hPrev, inputs);
+      }
+
+      INDArray dWxh = Nd4j.zerosLike(wxh);
+      INDArray dWhh = Nd4j.zerosLike(whh);
+      INDArray dWhy = Nd4j.zerosLike(why);
+
+      INDArray dbh = Nd4j.zerosLike(bh);
+      INDArray dby = Nd4j.zerosLike(by);
+
+      // forward seqLength characters through the net and fetch gradient
+      double loss = lossFun(vocabSize, wxh, whh, why, bh, by, hPrev, inputs, targets, dWxh, dWhh, dWhy, dbh, dby);
+      smoothLoss = smoothLoss * 0.99 + loss * 0.001;
+      if (n % 100 == 0) {
+        System.out.printf("iter %d, loss: %f\n", n, smoothLoss); // print 
progress
+      }
+
+      // perform parameter update with Adagrad
+      mWxh.addi(dWxh.mul(dWxh));
+      wxh.subi((dWxh.mul(learningRate)).div(Transforms.sqrt(mWxh.addi(1e-8))));
+
+      mWhh.addi(dWhh.mul(dWhh));
+      whh.subi(dWhh.mul(learningRate).div(Transforms.sqrt(mWhh.addi(1e-8))));
+
+      mWhy.addi(dWhy.mul(dWhy));
+      why.subi(dWhy.mul(learningRate).div(Transforms.sqrt(mWhy.addi(1e-8))));
+
+      mbh.addi(dbh.mul(dbh));
+      bh.subi(dbh.mul(learningRate).div(Transforms.sqrt(mbh.addi(1e-8))));
+
+      mby.addi(dby.mul(dby));
+      by.subi(dby.mul(learningRate).div(Transforms.sqrt(mby.addi(1e-8))));
+
+      p += seqLength; // move data pointer
+      n++; // iteration counter
+    }
+  }
+
+  /**
+   * inputs, targets are both list of integers
+   * hprev is Hx1 array of initial hidden state
+   * returns the loss, gradients on model parameters and last hidden state
+   */
+  private double lossFun(int vocabSize, INDArray wxh, INDArray whh, INDArray why, INDArray bh, INDArray by, INDArray hPrev,
+                         INDArray inputs, INDArray targets, INDArray dWxh, INDArray dWhh, INDArray dWhy, INDArray dbh,
+                         INDArray dby) {
+
+    INDArray xs = Nd4j.zeros(inputs.length(), vocabSize);
+    INDArray hs = null;
+    INDArray ys = null;
+    INDArray ps = null;
+
+    INDArray hs1 = Nd4j.create(hPrev.shape());
+    Nd4j.copy(hPrev, hs1);
+
+    double loss = 0;
+
+    // forward pass
+    for (int t = 0; t < inputs.length(); t++) {
+      int tIndex = inputs.getScalar(t).getInt(0);
+      xs.putScalar(t, tIndex, 1); // encode in 1-of-k representation
+      INDArray hsRow = t == 0 ? hs1 : hs.getRow(t - 1);
+      INDArray hst = Transforms.tanh((wxh.mmul(xs.getRow(t).transpose())).add((whh.mmul(hsRow)).add(bh))); // hidden state
+      if (hs == null) {
+        hs = init(inputs.length(), hst);
+      }
+      hs.putRow(t, hst);
+
+      INDArray yst = (why.mmul(hst)).add(by); // unnormalized log probabilities for next chars
+      if (ys == null) {
+        ys = init(inputs.length(), yst);
+      }
+      ys.putRow(t, yst);
+      INDArray exp = Transforms.exp(yst);
+      Number sumExp = exp.sumNumber();
+      INDArray pst = exp.div(sumExp); // probabilities for next chars
+      if (ps == null) {
+        ps = init(inputs.length(), pst);
+      }
+      ps.putRow(t, pst);
+      loss += -Transforms.log(ps.getRow(t).getRow(targets.getInt(t)), true).sumNumber().doubleValue(); // softmax (cross-entropy loss)
+    }
+
+    // backward pass: compute gradients going backwards
+    INDArray dhNext = Nd4j.zerosLike(hs.getRow(0));
+    for (int t = inputs.length() - 1; t >= 0; t--) {
+      INDArray dy = ps.getRow(t).dup();
+      dy.putRow(targets.getInt(t), dy.getRow(targets.getInt(t)).sub(1)); // backprop into y
+      INDArray hst = hs.getRow(t);
+      dWhy.addi(dy.mmul(hst.transpose()));
+      dby.addi(dy);
+      INDArray dh = why.transpose().mmul(dy).add(dhNext); // backprop into h
+      INDArray dhraw = (Nd4j.ones(hst.shape()).sub(hst).mul(hst)).mul(dh); // backprop through tanh nonlinearity
+      dbh.addi(dhraw);
+      dWxh.addi(dhraw.mmul(xs.getRow(t)));
+      INDArray hsRow = t == 0 ? hs1 : hs.getRow(t - 1);
+      dWhh.addi(dhraw.mmul(hsRow.transpose()));
+      dhNext = whh.transpose().mmul(dhraw);
+
+    }
+    // clip exploding gradients
+    Nd4j.getExecutioner().execAndReturn(new SetRange(dWxh, -5, 5));
+    Nd4j.getExecutioner().execAndReturn(new SetRange(dWhh, -5, 5));
+    Nd4j.getExecutioner().execAndReturn(new SetRange(dWhy, -5, 5));
+    Nd4j.getExecutioner().execAndReturn(new SetRange(dbh, -5, 5));
+    Nd4j.getExecutioner().execAndReturn(new SetRange(dby, -5, 5));
+
+    return loss;
+  }
+
+  private INDArray init(int t, INDArray ast) {
+    INDArray as;
+    int[] aShape = ast.shape();
+    int[] shape = new int[1 + aShape.length];
+    shape[0] = t;
+    System.arraycopy(aShape, 0, shape, 1, aShape.length);
+    as = Nd4j.create(shape);
+    return as;
+  }
+
+  /**
+   * sample a sequence of integers from the model, h is memory state, seed_ix is seed letter for first time step
+   */
+  private void sample(int vocabSize, Map<Integer, Character> ixToChar, INDArray wxh, INDArray whh, INDArray why,
+                      INDArray bh, INDArray by, INDArray hPrev, INDArray inputs) {
+
+    INDArray x = Nd4j.zeros(vocabSize, 1);
+    int seedIx = inputs.getInt(0);
+    x.putScalar(seedIx, 1);
+    int sampleSize = 200;
+    INDArray ixes = Nd4j.create(sampleSize);
+
+    for (int t = 0; t < sampleSize; t++) {
+      INDArray h = Transforms.tanh((wxh.mmul(x)).add((whh.mmul(hPrev)).add(bh)));
+      INDArray y = (why.mmul(h)).add(by);
+      INDArray exp = Transforms.exp(y);
+      INDArray pm = exp.div(Nd4j.sum(exp)).ravel();
+
+      List<Pair<Integer, Double>> d = new LinkedList<>();
+      for (int pi = 0; pi < vocabSize; pi++) {
+        d.add(new Pair<>(pi, pm.getDouble(0, pi)));
+      }
+      EnumeratedDistribution<Integer> distribution = new EnumeratedDistribution<>(d);
+
+      int ix = distribution.sample();
+
+      x = Nd4j.zeros(vocabSize, 1);
+      x.putScalar(ix, 1);
+      ixes.putScalar(t, ix);
+    }
+
+    String txt = "";
+
+
+    NdIndexIterator ndIndexIterator = new NdIndexIterator(ixes.shape());
+    while (ndIndexIterator.hasNext()) {
+      int[] next = ndIndexIterator.next();
+      txt += ixToChar.get(ixes.getInt(next));
+    }
+    System.out.printf("\n---\n %s \n----\n", txt);
+  }
+
+}

Propchange: labs/yay/trunk/core/src/main/java/org/apache/yay/RNN.java
------------------------------------------------------------------------------
    svn:eol-style = native

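For readers comparing against Karpathy's numpy version: the parameter update in the learn() loop above is plain Adagrad, applied element-wise to wxh, whh, why, bh and by. A scalar sketch of the same rule (the values are illustrative; the 1e-8 term only guards against division by zero):

    public class AdagradStep {
      public static void main(String[] args) {
        double learningRate = 1e-2; // same value as in RNN.learn()
        double param = 0.5;         // one weight element, e.g. an entry of wxh
        double grad = 0.3;          // its gradient for the current sequence, e.g. an entry of dWxh
        double memory = 0.0;        // its Adagrad accumulator, e.g. an entry of mWxh

        memory += grad * grad;                                   // mWxh.addi(dWxh.mul(dWxh))
        param -= learningRate * grad / Math.sqrt(memory + 1e-8); // wxh.subi(dWxh.mul(lr).div(sqrt(mWxh + 1e-8)))
        System.out.println(param);
      }
    }

Because the accumulator only grows, the effective step size shrinks for frequently updated parameters, which is why no explicit learning-rate schedule is used here.
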
Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/SkipGramNetwork.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/SkipGramNetwork.java?rev=1763464&r1=1763463&r2=1763464&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/SkipGramNetwork.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/SkipGramNetwork.java Wed Oct  5 16:12:24 2016
@@ -62,8 +62,8 @@ public class SkipGramNetwork {
   * the first row of weighs[0] matrix holds the weights of each neuron in the first neuron of the second layer,
   * the second row of weighs[0] holds the weights of each neuron in the second neuron of the second layer, etc.
    */
-  private RealMatrix[] weights;
-  private RealMatrix[] biases;
+  private final RealMatrix[] weights;
+  private final RealMatrix[] biases;
   private Sample[] samples;
 
 
@@ -184,7 +184,7 @@ public class SkipGramNetwork {
    * @return the output
    * @throws Exception
    */
-  public double[] predictOutput(double[] input) throws Exception {
+  private double[] predictOutput(double[] input) throws Exception {
 
     RealMatrix hidden = rectifierFunction.applyMatrix(MatrixUtils.createRowRealMatrix(input).multiply(weights[0].transpose()).
             add(biases[0]));
@@ -214,7 +214,7 @@ public class SkipGramNetwork {
    * @return the final cost with the updated weights
    * @throws Exception if BGD fails to converge or any numerical error happens
    */
-  public double learnWeights(Sample... samples) throws Exception {
+  private double learnWeights(Sample... samples) throws Exception {
 
     int iterations = 0;
 
@@ -250,7 +250,7 @@ public class SkipGramNetwork {
       }
 
       RealMatrix w0t = weights[0].transpose();
-      final RealMatrix w1t = weights[1].transpose();
+      RealMatrix w1t = weights[1].transpose();
 
       RealMatrix hidden = rectifierFunction.applyMatrix(x.multiply(w0t));
       hidden.walkInOptimizedOrder(new RealMatrixChangingVisitor() {
@@ -912,23 +912,23 @@ public class SkipGramNetwork {
 
   private static class Configuration {
     // internal parameters
-    protected int outputs;
-    protected int inputs;
+    int outputs;
+    int inputs;
 
-    protected List<String> vocabulary;
+    List<String> vocabulary;
 
     // user controlled parameters
-    protected Path path;
-    protected int maxIterations;
-    protected double alpha = 0.5d;
-    protected double mu = 0.9d;
-    protected double regularizationLambda = 0.03;
-    protected double threshold = 0.0000000000004d;
-    protected int vectorSize;
-    protected int window;
-    protected boolean useMomentum;
-    protected boolean useNesterovMomentum;
-    protected int batchSize;
+    Path path;
+    int maxIterations;
+    double alpha = 0.5d;
+    double mu = 0.9d;
+    double regularizationLambda = 0.03;
+    double threshold = 0.0000000000004d;
+    int vectorSize;
+    int window;
+    boolean useMomentum;
+    boolean useNesterovMomentum;
+    int batchSize;
   }
 
   public static class Builder {
@@ -1039,7 +1039,7 @@ public class SkipGramNetwork {
       return vocabulary;
     }
 
-    private Collection<HotEncodedSample> createTrainingSet(final List<String> vocabulary, Queue<List<byte[]>> fragments, int window) throws IOException {
+    private Collection<HotEncodedSample> createTrainingSet(final List<String> vocabulary, Queue<List<byte[]>> fragments, int window) throws Exception {
       long start = System.currentTimeMillis();
       Collection<HotEncodedSample> samples = new LinkedList<>();
       List<byte[]> fragment;
@@ -1063,8 +1063,9 @@ public class SkipGramNetwork {
         String x = new String(inputWord);
         inputs[0] = (double) vocabulary.indexOf(x);
 
-        samples.add(new HotEncodedSample(inputs, doubles, vocabulary.size()));
-//        System.err.println("added: " + x + " -> " + Arrays.toString(os.toArray()));
+        HotEncodedSample hotEncodedSample = new HotEncodedSample(inputs, doubles, vocabulary.size());
+        samples.add(hotEncodedSample);
+//        System.err.println("added: " + x + " -> " + hotEncodedSample);
       }
 
       long end = System.currentTimeMillis();

Modified: labs/yay/trunk/core/src/test/java/org/apache/yay/MultiLayerNetworkTest.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/MultiLayerNetworkTest.java?rev=1763464&r1=1763463&r2=1763464&view=diff
==============================================================================
--- labs/yay/trunk/core/src/test/java/org/apache/yay/MultiLayerNetworkTest.java (original)
+++ labs/yay/trunk/core/src/test/java/org/apache/yay/MultiLayerNetworkTest.java Wed Oct  5 16:12:24 2016
@@ -36,7 +36,7 @@ public class MultiLayerNetworkTest {
   @Test
   public void testLearnAndPredict() throws Exception {
     MultiLayerNetwork.Configuration configuration = new MultiLayerNetwork.Configuration();
-    configuration.alpha = 0.0000001d;
+    configuration.alpha = 0.000000001d;
     configuration.layers = new int[]{3, 4, 1};
     configuration.maxIterations = 1000000;
     configuration.threshold = 0.00000004d;

Added: labs/yay/trunk/core/src/test/java/org/apache/yay/RNNTest.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/RNNTest.java?rev=1763464&view=auto
==============================================================================
--- labs/yay/trunk/core/src/test/java/org/apache/yay/RNNTest.java (added)
+++ labs/yay/trunk/core/src/test/java/org/apache/yay/RNNTest.java Wed Oct  5 16:12:24 2016
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *  http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.yay;
+
+import org.apache.commons.io.IOUtils;
+import org.junit.Test;
+
+import java.io.InputStream;
+
+/**
+ * Tests for {@link RNN}
+ */
+public class RNNTest {
+
+  @Test
+  public void test() throws Exception {
+    InputStream resourceAsStream = getClass().getResourceAsStream("/word2vec/abstracts.txt");
+    String text = IOUtils.toString(resourceAsStream);
+    RNN n = new RNN();
+    n.learn(text);
+  }
+
+}
\ No newline at end of file

Propchange: labs/yay/trunk/core/src/test/java/org/apache/yay/RNNTest.java
------------------------------------------------------------------------------
    svn:eol-style = native

Modified: labs/yay/trunk/core/src/test/java/org/apache/yay/SkipGramNetworkTest.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/SkipGramNetworkTest.java?rev=1763464&r1=1763463&r2=1763464&view=diff
==============================================================================
--- labs/yay/trunk/core/src/test/java/org/apache/yay/SkipGramNetworkTest.java (original)
+++ labs/yay/trunk/core/src/test/java/org/apache/yay/SkipGramNetworkTest.java Wed Oct  5 16:12:24 2016
@@ -47,12 +47,33 @@ public class SkipGramNetworkTest {
             withWindow(3).
             fromTextAt(path).
             withDimension(10).
-            withAlpha(0.01).
+            withAlpha(0.1).
             withLambda(0.0001).
             useNesterovMomentum(true).
             withMu(0.9).
             withMaxIterations(30000).
-            withBatchSize(10).
+            withBatchSize(1).
+            build();
+    RealMatrix wv = network.getWeights()[0];
+    List<String> vocabulary = network.getVocabulary();
+    serialize(vocabulary, wv);
+    System.err.println("accuracy: " + SkipGramNetwork.evaluate(network));
+    measure(vocabulary, wv);
+  }
+
+  @Test
+  public void testWordVectorsLearningOnBigText() throws Exception {
+    Path path = Paths.get(getClass().getResource("/word2vec/big.txt").getFile());
+    SkipGramNetwork network = SkipGramNetwork.newModel().
+            withWindow(3).
+            fromTextAt(path).
+            withDimension(2).
+            withAlpha(0.1).
+            withLambda(0.0001).
+            useNesterovMomentum(true).
+            withMu(0.9).
+            withMaxIterations(1000000).
+            withBatchSize(1).
             build();
     RealMatrix wv = network.getWeights()[0];
     List<String> vocabulary = network.getVocabulary();


