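Author: tommaso Date: Mon Mar 14 08:42:10 2016 New Revision: 1734888 URL: http://svn.apache.org/viewvc?rev=1734888&view=rev Log: fixed bug in mini-batch training (the mini-batch index `j` is now advanced after each batch, and the batch matrices are allocated once outside the training loop); also removed unused imports, locals and methods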
Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/SkipGramNetwork.java Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/SkipGramNetwork.java URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/SkipGramNetwork.java?rev=1734888&r1=1734887&r2=1734888&view=diff ============================================================================== --- labs/yay/trunk/core/src/main/java/org/apache/yay/SkipGramNetwork.java (original) +++ labs/yay/trunk/core/src/main/java/org/apache/yay/SkipGramNetwork.java Mon Mar 14 08:42:10 2016 @@ -19,7 +19,6 @@ package org.apache.yay; import com.google.common.base.Splitter; -import org.apache.commons.lang3.ArrayUtils; import org.apache.commons.math3.distribution.UniformRealDistribution; import org.apache.commons.math3.linear.MatrixUtils; import org.apache.commons.math3.linear.RealMatrix; @@ -29,20 +28,15 @@ import org.apache.commons.math3.linear.R import java.io.IOException; import java.nio.ByteBuffer; -import java.nio.CharBuffer; import java.nio.channels.SeekableByteChannel; -import java.nio.charset.Charset; import java.nio.file.Files; import java.nio.file.Path; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; -import java.util.Collections; -import java.util.HashSet; import java.util.LinkedList; import java.util.List; import java.util.Queue; -import java.util.Set; import java.util.concurrent.ConcurrentLinkedDeque; import java.util.regex.Pattern; @@ -236,24 +230,23 @@ public class SkipGramNetwork { long start = System.currentTimeMillis(); int c = 1; + RealMatrix x = MatrixUtils.createRealMatrix(configuration.batchSize, samples[0].getInputs().length); + RealMatrix y = MatrixUtils.createRealMatrix(configuration.batchSize, samples[0].getOutputs().length); while (true) { - RealMatrix x = MatrixUtils.createRealMatrix(configuration.batchSize, samples[0].getInputs().length); - RealMatrix y = MatrixUtils.createRealMatrix(configuration.batchSize, 
samples[0].getOutputs().length); int i = 0; for (int k = j * configuration.batchSize; k < j * configuration.batchSize + configuration.batchSize; k++) { Sample sample = samples[k % samples.length]; - x.setRow(i, ArrayUtils.addAll(sample.getInputs())); - y.setRow(i, ArrayUtils.addAll(sample.getOutputs())); + x.setRow(i, sample.getInputs()); + y.setRow(i, sample.getOutputs()); i++; } + j++; long time = (System.currentTimeMillis() - start) / 1000; - if (iterations % (1 + (configuration.maxIterations / 100)) == 0 || time % 300 == 0) { - if (time > 60 * c) { - c += 1; - System.out.println("cost: " + cost + ", accuracy: " + evaluate(this) + " after " + iterations + " iterations in " + (time / 60) + " minutes (" + ((double) iterations / time) + " ips)"); - } + if (iterations % (1 + (configuration.maxIterations / 100)) == 0 && time > 60 * c) { + c += 1; + System.out.println("cost: " + cost + ", accuracy: " + evaluate(this) + " after " + iterations + " iterations in " + (time / 60) + " minutes (" + ((double) iterations / time) + " ips)"); } RealMatrix w0t = weights[0].transpose(); @@ -384,9 +377,9 @@ public class SkipGramNetwork { System.out.println("started with cost = " + dataLoss + " + " + regLoss + " = " + newCost); } - if (Double.POSITIVE_INFINITY == newCost || newCost > cost) { + if (Double.POSITIVE_INFINITY == newCost) { throw new Exception("failed to converge at iteration " + iterations + " with alpha " + configuration.alpha + " : cost going from " + cost + " to " + newCost); - } else if (iterations > 1 && (newCost < configuration.threshold || iterations > configuration.maxIterations || cost - newCost < configuration.threshold)) { + } else if (iterations > 1 && (newCost < configuration.threshold || iterations > configuration.maxIterations)) { cost = newCost; System.out.println("successfully converged after " + (iterations - 1) + " iterations (alpha:" + configuration.alpha + ",threshold:" + configuration.threshold + ") with cost " + newCost); break; @@ -1060,11 
+1053,9 @@ public class SkipGramNetwork { } } - List<String> os = new LinkedList<>(); double[] doubles = new double[window - 1]; for (int i = 0; i < doubles.length; i++) { String o = new String(outputWords.get(i)); - os.add(o); doubles[i] = (double) vocabulary.indexOf(o); } @@ -1143,91 +1134,6 @@ public class SkipGramNetwork { } } - private Queue<List<byte[]>> getFragmentsOld(Path path, int w) throws IOException { - long start = System.currentTimeMillis(); - Queue<List<byte[]>> fragments = new ConcurrentLinkedDeque<>(); - - ByteBuffer buf = ByteBuffer.allocate(100); - try (SeekableByteChannel sbc = Files.newByteChannel(path)) { - - String encoding = System.getProperty("file.encoding"); - StringBuilder previous = new StringBuilder(); - Splitter splitter = Splitter.on(Pattern.compile("[\\n\\s]")).omitEmptyStrings().trimResults(); - while (sbc.read(buf) > 0) { - buf.rewind(); - CharBuffer charBuffer = Charset.forName(encoding).decode(buf); - String string = cleanString(charBuffer.toString()); - List<String> split = splitter.splitToList(string); - int splitSize = split.size(); - if (splitSize > w) { - for (int j = 0; j < splitSize - w; j++) { - List<byte[]> fragment = new ArrayList<>(w); - String str = split.get(j); - fragment.add(previous.append(str).toString().getBytes()); - for (int i = 1; i < w; i++) { - String s = split.get(i + j); - fragment.add(s.getBytes()); - } - // TODO : this has to be used to re-use the tokens that have not been consumed in next iteration - fragments.add(fragment); - previous = new StringBuilder(); - } - previous = new StringBuilder().append(split.get(splitSize - 1)); - } else if (split.size() == w) { - previous.append(string); - } - buf.flip(); - } - } catch (IOException x) { - System.err.println("caught exception: " + x); - } finally { - buf.clear(); - } - long end = System.currentTimeMillis(); - System.out.println("fragments read in " + (end - start) / 60000 + " minutes (" + fragments.size() + ")"); - return fragments; - } - - private 
List<String> getVocabulary(Path path) throws IOException { - Set<String> vocabulary = new HashSet<>(); - ByteBuffer buf = ByteBuffer.allocate(100); - try (SeekableByteChannel sbc = Files.newByteChannel(path)) { - - String encoding = System.getProperty("file.encoding"); - StringBuilder previous = new StringBuilder(); - Splitter splitter = Splitter.on(Pattern.compile("[\\\n\\s]")).omitEmptyStrings().trimResults(); - while (sbc.read(buf) > 0) { - buf.rewind(); - CharBuffer charBuffer = Charset.forName(encoding).decode(buf); - String string = cleanString(charBuffer.toString()); - List<String> split = splitter.splitToList(string); - int splitSize = split.size(); - if (splitSize > 1) { - String term = previous.append(split.get(0)).toString(); - vocabulary.add(term.intern()); - for (int i = 1; i < splitSize - 1; i++) { - String term2 = split.get(i); - vocabulary.add(term2.intern()); - } - previous = new StringBuilder().append(split.get(splitSize - 1)); - } else if (split.size() == 1) { - previous.append(string); - } - buf.flip(); - } - } catch (IOException x) { - System.err.println("caught exception: " + x); - } finally { - buf.clear(); - } - List<String> list = Arrays.asList(vocabulary.toArray(new String[vocabulary.size()])); - Collections.sort(list); -// for (String iw : vocabulary) { -// System.out.println(iw +"->"+Arrays.toString(ConversionUtils.hotEncode(iw.getBytes(), list))); -// } - return list; - } - private String cleanString(String s) { return s.toLowerCase().replaceAll("\\.", " \\.").replaceAll("\\;", " \\;").replaceAll("\\,", " \\,").replaceAll("\\:", " \\:").replaceAll("\\-\\s", "").replaceAll("\\\"", " \\\""); } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@labs.apache.org For additional commands, e-mail: commits-h...@labs.apache.org