Author: tommaso
Date: Mon Nov 23 08:32:44 2015
New Revision: 1715735

URL: http://svn.apache.org/viewvc?rev=1715735&view=rev
Log:
various performance improvements
Modified:
    labs/yay/trunk/api/src/main/java/org/apache/yay/TrainingSet.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/core/CrossEntropyCostFunction.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/core/DefaultDerivativeUpdateFunction.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/core/LMSCostFunction.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/core/LogisticRegressionCostFunction.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/core/SoftmaxActivationFunction.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/core/neuron/LinearNeuron.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/core/utils/ConversionUtils.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/core/utils/ExamplesFactory.java
    labs/yay/trunk/core/src/test/java/org/apache/yay/core/WordVectorsTest.java

Modified: labs/yay/trunk/api/src/main/java/org/apache/yay/TrainingSet.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/api/src/main/java/org/apache/yay/TrainingSet.java?rev=1715735&r1=1715734&r2=1715735&view=diff
==============================================================================
--- labs/yay/trunk/api/src/main/java/org/apache/yay/TrainingSet.java (original)
+++ labs/yay/trunk/api/src/main/java/org/apache/yay/TrainingSet.java Mon Nov 23 08:32:44 2015
@@ -40,4 +40,8 @@ public class TrainingSet<F, O> implement
   public int size() {
     return samples.size();
   }
+
+  public TrainingExample[] toArray() {
+    return samples.toArray(new TrainingExample[size()]);
+  }
 }
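The new toArray() exists to support the mini-batch changes in BackPropagationLearningStrategy below: callers can materialize the samples once and reuse the array across iterations instead of re-wrapping the collection on every pass. The general pattern, sketched with plain JDK types standing in for the Yay classes (names here are illustrative only):

    import java.util.Arrays;
    import java.util.List;

    public class MaterializeOnceSketch {
      public static void main(String[] args) {
        List<String> samples = Arrays.asList("a", "b", "c");
        // materialize the backing array once, outside the training loop,
        // instead of copying or wrapping the collection on every iteration
        String[] array = samples.toArray(new String[samples.size()]);
        for (int iteration = 0; iteration < 3; iteration++) {
          // use 'array' directly; no per-iteration allocation
          System.out.println(iteration + ": " + array.length + " samples");
        }
      }
    }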
Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java?rev=1715735&r1=1715734&r2=1715735&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java Mon Nov 23 08:32:44 2015
@@ -87,24 +87,26 @@ public class BackPropagationLearningStra
     Iterator<TrainingExample<Double, Double>> iterator = trainingExamples.iterator();
     double cost = Double.MAX_VALUE;
+    long start = System.currentTimeMillis();
     while (true) {
-      System.err.println(iterations);
-      TrainingSet<Double, Double> samples;
-      if (batch == -1) {
-        samples = trainingExamples;
-      } else {
-        TrainingExample<Double, Double>[] miniBatch = new TrainingExample[batch];
+      if (iterations > 0 && iterations % (maxIterations / 1000d) == 0) {
+        long time = (System.currentTimeMillis() - start) / 1000;
+        if (time / 60 > 2) {
+          System.out.println(iterations + " iterations in " + (time / 60) + " minutes (" + ((double) iterations / time) + " ips)");
+        }
+      }
+      TrainingExample<Double, Double>[] miniBatch = batch > 0 ? new TrainingExample[batch] : trainingExamples.toArray();
+      if (batch > 0) {
         for (int i = 0; i < batch; i++) {
           if (!iterator.hasNext()) {
             iterator = trainingExamples.iterator();
           }
           miniBatch[i] = iterator.next();
         }
-        samples = new TrainingSet<Double, Double>(Arrays.asList(miniBatch));
       }

       // calculate cost
-      double newCost = costFunction.calculateAggregatedCost(samples, neuralNetwork);
+      double newCost = costFunction.calculateCost(neuralNetwork, miniBatch);
       if (Double.POSITIVE_INFINITY == newCost || newCost > cost && batch == -1) {
         throw new RuntimeException("failed to converge at iteration " + iterations + " with alpha " + alpha + " : cost going from " + cost + " to " + newCost);
@@ -118,8 +120,10 @@ public class BackPropagationLearningStra
       // update registered cost
       cost = newCost;

+      TrainingSet<Double, Double> trainingSet = batch < 0 ? trainingExamples : new TrainingSet<>(Arrays.asList(miniBatch));
+
       // calculate the derivatives to update the parameters
-      RealMatrix[] derivatives = derivativeUpdateFunction.updateParameters(weightsMatrixSet, samples);
+      RealMatrix[] derivatives = derivativeUpdateFunction.updateParameters(weightsMatrixSet, trainingSet);

       // calculate the updated parameters
       updatedWeights = updateWeights(updatedWeights, derivatives, alpha);
@@ -141,8 +145,8 @@ public class BackPropagationLearningStra
     for (int l = 0; l < weightsMatrixSet.length; l++) {
       RealMatrix realMatrix = weightsMatrixSet[l].copy();

-      final double[][] data = derivatives[l].getData();
-      RealMatrixChangingVisitor visitor = new RealMatrixChangingVisitor(){
+      final int finalL = l;
+      RealMatrixChangingVisitor visitor = new RealMatrixChangingVisitor() {

         @Override
         public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) {
@@ -152,7 +156,7 @@ public class BackPropagationLearningStra
         @Override
         public double visit(int row, int column, double value) {
           if (!(row == 0 && value == 0d) && !(column == 0 && value == 1d)) {
-            return value - alpha * data[row][column];
+            return value - alpha * derivatives[finalL].getEntry(row, column);
           } else {
             return value;
           }
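The key change in updateWeights above is dropping derivatives[l].getData(), which copies the whole matrix into a double[][], in favor of reading single entries from inside the visitor. A self-contained sketch of the same pattern against commons-math3 (matrix shapes and the alpha value are invented for illustration):

    import org.apache.commons.math3.linear.MatrixUtils;
    import org.apache.commons.math3.linear.RealMatrix;
    import org.apache.commons.math3.linear.RealMatrixChangingVisitor;

    public class VisitorUpdateSketch {
      public static void main(String[] args) {
        final double alpha = 0.01d; // learning rate, illustrative value
        RealMatrix weights = MatrixUtils.createRealMatrix(new double[][]{{1d, 2d}, {3d, 4d}});
        final RealMatrix gradient = MatrixUtils.createRealMatrix(new double[][]{{0.1d, 0.2d}, {0.3d, 0.4d}});
        // update every entry in place, reading the gradient entry by entry
        // instead of copying it out with getData()
        weights.walkInOptimizedOrder(new RealMatrixChangingVisitor() {
          @Override
          public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) {
          }

          @Override
          public double visit(int row, int column, double value) {
            return value - alpha * gradient.getEntry(row, column);
          }

          @Override
          public double end() {
            return 0d;
          }
        });
        System.out.println(weights);
      }
    }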
Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/CrossEntropyCostFunction.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/CrossEntropyCostFunction.java?rev=1715735&r1=1715734&r2=1715735&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/CrossEntropyCostFunction.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/CrossEntropyCostFunction.java Mon Nov 23 08:32:44 2015
@@ -42,8 +42,9 @@ public class CrossEntropyCostFunction im
     return calculateCost(hypothesis, samples);
   }

-  private Double calculateErrorTerm(Hypothesis<RealMatrix, Double, Double> hypothesis,
-                                    TrainingExample<Double, Double>... trainingExamples) throws PredictionException {
+  @SafeVarargs
+  private final Double calculateErrorTerm(Hypothesis<RealMatrix, Double, Double> hypothesis,
+                                          TrainingExample<Double, Double>... trainingExamples) throws PredictionException {
     Double res = 0d;

     for (TrainingExample<Double, Double> input : trainingExamples) {
@@ -58,8 +59,9 @@ public class CrossEntropyCostFunction im
     return res;
   }

+  @SafeVarargs
   @Override
-  public Double calculateCost(Hypothesis<RealMatrix, Double, Double> hypothesis, TrainingExample<Double, Double>... trainingExamples) throws Exception {
+  public final Double calculateCost(Hypothesis<RealMatrix, Double, Double> hypothesis, TrainingExample<Double, Double>... trainingExamples) throws Exception {
     return calculateErrorTerm(hypothesis, trainingExamples);
   }
 }

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/DefaultDerivativeUpdateFunction.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/DefaultDerivativeUpdateFunction.java?rev=1715735&r1=1715734&r2=1715735&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/DefaultDerivativeUpdateFunction.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/DefaultDerivativeUpdateFunction.java Mon Nov 23 08:32:44 2015
@@ -18,17 +18,20 @@
  */
 package org.apache.yay.core;

-import java.util.Arrays;
-
 import org.apache.commons.math3.linear.ArrayRealVector;
 import org.apache.commons.math3.linear.RealMatrix;
 import org.apache.commons.math3.linear.RealVector;
 import org.apache.yay.DerivativeUpdateFunction;
+import org.apache.yay.Feature;
 import org.apache.yay.PredictionStrategy;
 import org.apache.yay.TrainingExample;
 import org.apache.yay.TrainingSet;
 import org.apache.yay.core.utils.ConversionUtils;

+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+
 /**
  * Default derivatives update function
  */
@@ -51,7 +54,10 @@ class DefaultDerivativeUpdateFunction im
     for (TrainingExample<Double, Double> trainingExample : trainingExamples) {
       try {
         // get activations from feed forward propagation
-        RealVector[] activations = predictionStrategy.debugOutput(ConversionUtils.toValuesCollection(trainingExample.getFeatures()), weightsMatrixSet);
+        ArrayList<Feature<Double>> features = trainingExample.getFeatures();
+        Collection<Double> input = ConversionUtils.toValuesCollection(features);
+
+        RealVector[] activations = predictionStrategy.debugOutput(input, weightsMatrixSet);

         // calculate output error (corresponding to the last delta^l)
         RealVector nextLayerDelta = calculateOutputError(trainingExample, activations);
@@ -68,12 +74,12 @@ class DefaultDerivativeUpdateFunction im
         }

         RealVector[] newActivations = new RealVector[activations.length];
-        newActivations[0] = ConversionUtils.toRealVector(ConversionUtils.toValuesCollection(trainingExample.getFeatures()));
+        newActivations[0] = ConversionUtils.toRealVector(input);
         System.arraycopy(activations, 0, newActivations, 1, activations.length - 1);

-        // update triangle (big delta matrix)
-        updateTriangle(triangle, newActivations, deltaVectors, weightsMatrixSet);
+        // update triangle (big delta matrix)
+        updateTriangle(triangle, newActivations, deltaVectors);
       } catch (Exception e) {
         throw new RuntimeException("error during derivatives calculation", e);
       }
@@ -88,8 +94,8 @@ class DefaultDerivativeUpdateFunction im
     return derivatives;
   }

-  private void updateTriangle(RealMatrix[] triangle, RealVector[] activations, RealVector[] deltaVectors, RealMatrix[] weightsMatrixSet) {
-    for (int l = weightsMatrixSet.length - 1; l >= 0; l--) {
+  private void updateTriangle(RealMatrix[] triangle, RealVector[] activations, RealVector[] deltaVectors) {
+    for (int l = triangle.length - 1; l >= 0; l--) {
       RealMatrix realMatrix = deltaVectors[l].outerProduct(activations[l]);
       if (triangle[l] == null) {
         triangle[l] = realMatrix;
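A side note on the @SafeVarargs additions in this commit (they recur in the cost-function files below as well): the annotation is only permitted on methods that cannot be overridden, that is static, final, or constructor declarations, which is why the annotated methods became final in the same change. A minimal illustration of the warning it suppresses (the names here are generic placeholders, not Yay APIs):

    import java.util.Arrays;
    import java.util.List;

    public class SafeVarargsSketch {
      // without @SafeVarargs the compiler warns about possible
      // heap pollution from the parameterized vararg type T...
      @SafeVarargs
      private static <T> List<T> listOf(T... items) {
        return Arrays.asList(items);
      }

      public static void main(String[] args) {
        System.out.println(listOf("a", "b", "c"));
      }
    }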
Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java?rev=1715735&r1=1715734&r2=1715735&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java Mon Nov 23 08:32:44 2015
@@ -44,7 +44,7 @@ public class FeedForwardStrategy impleme
   private final Map<Integer, ActivationFunction<Double>> activationFunctionMap;

   public FeedForwardStrategy(ActivationFunction<Double> activationFunction) {
-    this.activationFunctionMap = new HashMap<Integer, ActivationFunction<Double>>();
+    this.activationFunctionMap = new HashMap<>();
     this.activationFunctionMap.put(0, activationFunction);
   }

@@ -69,15 +69,15 @@ public class FeedForwardStrategy impleme
     Double[] doubles = input.toArray(new Double[input.size()]);
     RealMatrix x = MatrixUtils.createRowRealMatrix(Stream.of(doubles).mapToDouble(Double::doubleValue).toArray());
     for (int w = 0; w < realMatrixSet.length; w++) {
-      final RealMatrix currentWeightsMatrix = realMatrixSet[w];
       // compute matrix multiplication
-      x = x.multiply(currentWeightsMatrix.transpose());
+      x = x.multiply(realMatrixSet[w].transpose());

-      // apply the activation function to each element in the matrix
-      final RealMatrix cm = x.getRowMatrix(0);
+      // get activation function for w-th layer
       int idx = activationFunctionMap.size() == realMatrixSet.length ? w : 0;
-      final ActivationFunction<Double> af = activationFunctionMap.get(idx);
-      x = af.applyMatrix(cm);
+
+      // apply the activation function to each element in the matrix
+      x = activationFunctionMap.get(idx).applyMatrix(x.getRowMatrix(0));
+
       debugOutput[w] = x.getRowVector(0);
     }
     return debugOutput;
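For reference, the refactored loop above is the standard feed-forward recurrence, x_{l+1} = f_l(x_l * W_l^T), applied once per weight matrix; the change just drops two per-layer temporaries. A stripped-down, runnable sketch of the matrix part with an identity activation (shapes and values invented for illustration):

    import org.apache.commons.math3.linear.MatrixUtils;
    import org.apache.commons.math3.linear.RealMatrix;

    public class FeedForwardSketch {
      public static void main(String[] args) {
        // 1x3 input row vector (a leading bias term included)
        RealMatrix x = MatrixUtils.createRowRealMatrix(new double[]{1d, 0.5d, -0.5d});
        // one 2x3 weight matrix: 2 output units, 3 inputs
        RealMatrix[] weights = {MatrixUtils.createRealMatrix(new double[][]{
            {0.1d, 0.2d, 0.3d},
            {0.4d, 0.5d, 0.6d}
        })};
        for (RealMatrix w : weights) {
          // x (1x3) times w^T (3x2) yields the next 1x2 activation row;
          // the identity activation is a no-op and therefore omitted
          x = x.multiply(w.transpose());
        }
        System.out.println(x);
      }
    }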
Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/LMSCostFunction.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/LMSCostFunction.java?rev=1715735&r1=1715734&r2=1715735&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/LMSCostFunction.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/LMSCostFunction.java Mon Nov 23 08:32:44 2015
@@ -41,8 +41,9 @@ public class LMSCostFunction implements
     return calculateCost(hypothesis, samples);
   }

+  @SafeVarargs
   @Override
-  public Double calculateCost(Hypothesis<RealMatrix, Double, Double> hypothesis, TrainingExample<Double, Double>... trainingExamples) throws Exception {
+  public final Double calculateCost(Hypothesis<RealMatrix, Double, Double> hypothesis, TrainingExample<Double, Double>... trainingExamples) throws Exception {
     Double cost = 0d;
     for (TrainingExample<Double, Double> example : trainingExamples) {
       Double[] actualOutput = example.getOutput();
Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/LogisticRegressionCostFunction.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/LogisticRegressionCostFunction.java?rev=1715735&r1=1715734&r2=1715735&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/LogisticRegressionCostFunction.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/LogisticRegressionCostFunction.java Mon Nov 23 08:32:44 2015
@@ -19,6 +19,7 @@
 package org.apache.yay.core;

 import org.apache.commons.math3.linear.RealMatrix;
+import org.apache.commons.math3.linear.RealMatrixPreservingVisitor;
 import org.apache.yay.Hypothesis;
 import org.apache.yay.NeuralNetworkCostFunction;
 import org.apache.yay.PredictionException;
@@ -53,23 +54,46 @@ public class LogisticRegressionCostFunct
     return calculateCost(hypothesis, samples);
   }

-  private Double calculateRegularizationTerm(Hypothesis<RealMatrix, Double, Double> hypothesis,
-                                             TrainingExample<Double, Double>... trainingExamples) {
+  @SafeVarargs
+  private final Double calculateRegularizationTerm(Hypothesis<RealMatrix, Double, Double> hypothesis,
+                                                   TrainingExample<Double, Double>... trainingExamples) {
     Double res = 1d;
     for (RealMatrix layerMatrix : hypothesis.getParameters()) {
-      for (int i = 0; i < layerMatrix.getColumnDimension(); i++) {
-        double[] column = layerMatrix.getColumn(i);
-        // starting from 1 to avoid including the bias unit in regularization
-        for (int j = 1; j < column.length; j++) {
-          res += Math.pow(column[j], 2d);
+      res += layerMatrix.walkInOptimizedOrder(new RealMatrixPreservingVisitor() {

+        private double res = 0d;
+
+        @Override
+        public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) {
+
-        }
+        }
+
+        @Override
+        public void visit(int row, int column, double value) {
+          if (column > 0) {
+            res += Math.pow(value, 2d);
+          }
+        }
+
+        @Override
+        public double end() {
+          return res;
+        }
+      });
+//      for (int i = 0; i < layerMatrix.getColumnDimension(); i++) {
+//        double[] column = layerMatrix.getColumn(i);
+//        // starting from 1 to avoid including the bias unit in regularization
+//        for (int j = 1; j < column.length; j++) {
+//          res += Math.pow(column[j], 2d);
+//        }
+//      }
     }
     return (lambda / (2d * trainingExamples.length)) * res;
   }

-  private Double calculateErrorTerm(Hypothesis<RealMatrix, Double, Double> hypothesis,
-                                    TrainingExample<Double, Double>... trainingExamples) throws PredictionException {
+  @SafeVarargs
+  private final Double calculateErrorTerm(Hypothesis<RealMatrix, Double, Double> hypothesis,
+                                          TrainingExample<Double, Double>... trainingExamples) throws PredictionException {
     Double res = 0d;

     for (TrainingExample<Double, Double> input : trainingExamples) {
@@ -85,8 +109,9 @@ public class LogisticRegressionCostFunct
     return (-1d / trainingExamples.length) * res;
   }

+  @SafeVarargs
   @Override
-  public Double calculateCost(Hypothesis<RealMatrix, Double, Double> hypothesis, TrainingExample<Double, Double>... trainingExamples) throws Exception {
+  public final Double calculateCost(Hypothesis<RealMatrix, Double, Double> hypothesis, TrainingExample<Double, Double>... trainingExamples) throws Exception {
     Double errorTerm = calculateErrorTerm(hypothesis, trainingExamples);
     Double regularizationTerm = calculateRegularizationTerm(hypothesis, trainingExamples);
     return errorTerm + regularizationTerm;

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/SoftmaxActivationFunction.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/SoftmaxActivationFunction.java?rev=1715735&r1=1715734&r2=1715735&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/SoftmaxActivationFunction.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/SoftmaxActivationFunction.java Mon Nov 23 08:32:44 2015
@@ -20,7 +20,7 @@ package org.apache.yay.core;

 import org.apache.commons.math3.linear.RealMatrix;
 import org.apache.commons.math3.linear.RealMatrixChangingVisitor;
-import org.apache.commons.math3.linear.RealVector;
+import org.apache.commons.math3.linear.RealMatrixPreservingVisitor;
 import org.apache.yay.ActivationFunction;

 /**
@@ -58,15 +58,24 @@ public class SoftmaxActivationFunction i
   }

   private double expDen(RealMatrix matrix) {
-    double d = 0d;
-    for (int i = 0; i < matrix.getRowDimension(); i++) {
-      RealVector currentRow = matrix.getRowVector(i);
-      for (int j = 0; j < matrix.getColumnDimension(); j++) {
-        double entry = currentRow.getEntry(j);
-        d += Math.exp(entry);
+    return matrix.walkInOptimizedOrder(new RealMatrixPreservingVisitor() {
+      private double d1 = 0d;
+
+      @Override
+      public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) {
+
+      }
+
+      @Override
+      public void visit(int row, int column, double value) {
+        d1 += Math.exp(value);
+      }
+
+      @Override
+      public double end() {
+        return d1;
       }
-    }
-    return d;
+    });
   }
 }
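Both files above swap index-based loops for a RealMatrixPreservingVisitor, which lets commons-math walk the backing storage in its preferred order and avoids per-row copies such as getRowVector(i). A self-contained sketch of the accumulation pattern (matrix values invented for illustration):

    import org.apache.commons.math3.linear.MatrixUtils;
    import org.apache.commons.math3.linear.RealMatrix;
    import org.apache.commons.math3.linear.RealMatrixPreservingVisitor;

    public class VisitorSumSketch {
      public static void main(String[] args) {
        RealMatrix m = MatrixUtils.createRealMatrix(new double[][]{{1d, 2d}, {3d, 4d}});
        // walkInOptimizedOrder returns whatever end() returns,
        // here the running sum of exp(entry) over all entries
        double expSum = m.walkInOptimizedOrder(new RealMatrixPreservingVisitor() {
          private double sum = 0d;

          @Override
          public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) {
          }

          @Override
          public void visit(int row, int column, double value) {
            sum += Math.exp(value);
          }

          @Override
          public double end() {
            return sum;
          }
        });
        System.out.println(expSum);
      }
    }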
Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/neuron/LinearNeuron.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/neuron/LinearNeuron.java?rev=1715735&r1=1715734&r2=1715735&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/neuron/LinearNeuron.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/neuron/LinearNeuron.java Mon Nov 23 08:32:44 2015
@@ -31,7 +31,7 @@ class LinearNeuron extends Neuron<Double
   private final double bias;

   LinearNeuron(double bias, double... weights) {
-    super(new IdentityActivationFunction<Double>(), weights);
+    super(new IdentityActivationFunction<>(), weights);
     this.bias = bias;
     this.weights = weights;
   }

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/utils/ConversionUtils.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/utils/ConversionUtils.java?rev=1715735&r1=1715734&r2=1715735&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/utils/ConversionUtils.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/utils/ConversionUtils.java Mon Nov 23 08:32:44 2015
@@ -31,14 +31,15 @@ import java.util.Collection;
 import java.util.Collections;
 import java.util.List;
 import java.util.WeakHashMap;
+import java.util.stream.Collectors;

 /**
  * Temporary class for conversion between model objects and commons-math matrices/vectors
 */
 public class ConversionUtils {

-  private static final WeakHashMap<String, Double[]> wordCache = new WeakHashMap<String, Double[]>();
-  private static final WeakHashMap<String, Integer> vocabularyCache = new WeakHashMap<String, Integer>();
+  private static final WeakHashMap<String, Double[]> wordCache = new WeakHashMap<>();
+  private static final WeakHashMap<String, Integer> vocabularyCache = new WeakHashMap<>();

   /**
    * Converts a set of examples to a matrix of inputs with features
@@ -94,12 +95,7 @@ public class ConversionUtils {
    * @return a vector of Doubles
    */
   public static <T> Collection<T> toValuesCollection(Collection<Feature<T>> featureVector) {
-    // TODO : remove this and change APIs in a way that doesn't force to go through this ugly loop
-    Collection<T> resultVector = new ArrayList<T>();
-    for (Feature<T> feature : featureVector) {
-      resultVector.add(feature.getValue());
-    }
-    return resultVector;
+    return featureVector.stream().map(Feature::getValue).collect(Collectors.toCollection(ArrayList::new));
   }

   /**
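The rewritten toValuesCollection above is the standard stream map-and-collect idiom; Collectors.toCollection(ArrayList::new) keeps the concrete collection type the old loop produced. The same pattern in isolation (the Feature type is stubbed here for illustration, it is not the Yay class):

    import java.util.ArrayList;
    import java.util.Arrays;
    import java.util.Collection;
    import java.util.stream.Collectors;

    public class StreamMapSketch {
      // stand-in for org.apache.yay.Feature<T>
      static class Feature<T> {
        private final T value;
        Feature(T value) { this.value = value; }
        T getValue() { return value; }
      }

      public static void main(String[] args) {
        Collection<Feature<Double>> features =
            Arrays.asList(new Feature<>(1d), new Feature<>(2d));
        // map each feature to its value and collect into an ArrayList
        Collection<Double> values = features.stream()
            .map(Feature::getValue)
            .collect(Collectors.toCollection(ArrayList::new));
        System.out.println(values); // [1.0, 2.0]
      }
    }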
Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/utils/ExamplesFactory.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/utils/ExamplesFactory.java?rev=1715735&r1=1715734&r2=1715735&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/utils/ExamplesFactory.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/utils/ExamplesFactory.java Mon Nov 23 08:32:44 2015
@@ -18,12 +18,12 @@
 */
 package org.apache.yay.core.utils;

-import java.util.ArrayList;
-
 import org.apache.yay.Feature;
 import org.apache.yay.Input;
 import org.apache.yay.TrainingExample;

+import java.util.ArrayList;
+
 /**
  * Factory class for {@link org.apache.yay.Input}s and {@link TrainingExample}s.
 */
@@ -60,21 +60,16 @@ public class ExamplesFactory {
   }

   public static Input<Double> createDoubleInput(final Double... featuresValues) {
-    return new Input<Double>() {
-      @Override
-      public ArrayList<Feature<Double>> getFeatures() {
-        return doublesToFeatureVector(featuresValues);
-      }
-    };
+    return () -> doublesToFeatureVector(featuresValues);
   }

   private static ArrayList<Feature<Double>> doublesToFeatureVector(Double[] featuresValues) {
-    ArrayList<Feature<Double>> features = new ArrayList<Feature<Double>>();
-    Feature<Double> byasFeature = new Feature<Double>();
+    ArrayList<Feature<Double>> features = new ArrayList<>();
+    Feature<Double> byasFeature = new Feature<>();
     byasFeature.setValue(1d);
     features.add(byasFeature);
     for (Double d : featuresValues) {
-      Feature<Double> feature = new Feature<Double>();
+      Feature<Double> feature = new Feature<>();
       feature.setValue(d);
       features.add(feature);
     }
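Replacing the anonymous Input<Double> with a lambda works because Input evidently has a single abstract method (getFeatures), making it usable as a functional interface. The equivalence in miniature, using a JDK interface as a stand-in for the Yay one:

    import java.util.function.Supplier;

    public class LambdaSketch {
      public static void main(String[] args) {
        // anonymous-class form, as createDoubleInput was written before
        Supplier<String> anon = new Supplier<String>() {
          @Override
          public String get() {
            return "features";
          }
        };
        // equivalent lambda form, as createDoubleInput is written now
        Supplier<String> lambda = () -> "features";
        System.out.println(anon.get() + " " + lambda.get());
      }
    }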
Modified: labs/yay/trunk/core/src/test/java/org/apache/yay/core/WordVectorsTest.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/core/WordVectorsTest.java?rev=1715735&r1=1715734&r2=1715735&view=diff
==============================================================================
--- labs/yay/trunk/core/src/test/java/org/apache/yay/core/WordVectorsTest.java (original)
+++ labs/yay/trunk/core/src/test/java/org/apache/yay/core/WordVectorsTest.java Mon Nov 23 08:32:44 2015
@@ -96,6 +96,7 @@ public class WordVectorsTest {

     int inputSize = next.getFeatures().size();
     int outputSize = next.getOutput().length;
+    int hiddenSize = 30;

     System.out.println("initializing neural network");
     RealMatrix[] randomWeights = createRandomWeights(inputSize, hiddenSize, outputSize);
@@ -104,7 +105,7 @@ public class WordVectorsTest {
     activationFunctions.put(0, new IdentityActivationFunction<Double>());
     activationFunctions.put(1, new SoftmaxActivationFunction());
     FeedForwardStrategy predictionStrategy = new FeedForwardStrategy(activationFunctions);
-    BackPropagationLearningStrategy learningStrategy = new BackPropagationLearningStrategy(0.01d, 1,
+    BackPropagationLearningStrategy learningStrategy = new BackPropagationLearningStrategy(0.001d, 1,
         BackPropagationLearningStrategy.DEFAULT_THRESHOLD, predictionStrategy, new CrossEntropyCostFunction(),
         trainingSet.size());
     NeuralNetwork neuralNetwork = NeuralNetworkFactory.create(randomWeights, learningStrategy, predictionStrategy);
@@ -189,9 +190,6 @@
       double maxSimilarity = -Double.MAX_VALUE;
       double maxSimilarity1 = -Double.MAX_VALUE;
       double maxSimilarity2 = -Double.MAX_VALUE;
-      double[] bestVector = null;
-      double[] bestVector1 = null;
-      double[] bestVector2 = null;
       int j0 = -1;
       int j1 = -1;
       int j2 = -1;
@@ -202,32 +200,26 @@
           double similarity = 1 / distanceMeasure.compute(subjectVector, vector);
           if (similarity > maxSimilarity) {
             maxSimilarity2 = maxSimilarity1;
-            bestVector2 = bestVector1;
             j2 = j1;
             maxSimilarity1 = maxSimilarity;
-            bestVector1 = bestVector;
             j1 = j0;
             maxSimilarity = similarity;
-            bestVector = vector;
             j0 = j;
           } else if (similarity > maxSimilarity1) {
             maxSimilarity2 = maxSimilarity1;
-            bestVector2 = bestVector1;
             j2 = j1;
             maxSimilarity1 = similarity;
-            bestVector1 = vector;
             j1 = j;
           } else if (similarity > maxSimilarity2) {
             maxSimilarity2 = similarity;
-            bestVector2 = vector;
             j2 = j;
           }
         }
       }
-      if (bestVector != null && i > 0 && j0 > 0 && j1 > 0 && j2 > 0) {
+      if (i > 0 && j0 > 0 && j1 > 0 && j2 > 0) {
         System.out.println(vocabulary.get(i - 1) + " is similar to "
             + vocabulary.get(j0 - 1) + ", "
             + vocabulary.get(j1 - 1) + ", "
@@ -244,8 +236,8 @@ public class WordVectorsTest {
     List<byte[]> fragment;
     while ((fragment = fragments.poll()) != null) {
       byte[] inputWord = null;
+      List<byte[]> outputWords = new ArrayList<>(fragment.size() - 1);
       for (int i = 0; i < fragment.size(); i++) {
-        List<byte[]> outputWords = new ArrayList<>(fragment.size() - 1);
         for (int j = 0; j < fragment.size(); j++) {
           byte[] token = fragment.get(i);
           if (i == j) {
@@ -254,28 +246,28 @@
           outputWords.add(token);
         }
       }
-
-      final byte[] finalInputWord = inputWord;
-      samples.add(new TrainingExample<Double, Double>() {
-        @Override
-        public Double[] getOutput() {
-          Double[] doubles = new Double[window - 1];
-          for (int i = 0; i < doubles.length; i++) {
-            doubles[i] = (double) vocabulary.indexOf(new String(outputWords.get(i)));
-          }
-          return doubles;
+      }
+      final byte[] finalInputWord = inputWord;
+      samples.add(new TrainingExample<Double, Double>() {
+        @Override
+        public Double[] getOutput() {
+          Double[] doubles = new Double[window - 1];
+          for (int i = 0; i < doubles.length; i++) {
+            doubles[i] = (double) vocabulary.indexOf(new String(outputWords.get(i)));
           }
+          return doubles;
+        }
+
+        @Override
+        public ArrayList<Feature<Double>> getFeatures() {
+          ArrayList<Feature<Double>> features = new ArrayList<>();
+          Feature<Double> e = new Feature<>();
+          e.setValue((double) vocabulary.indexOf(new String(finalInputWord)));
+          features.add(e);
+          return features;
+        }
+      });

-        @Override
-        public ArrayList<Feature<Double>> getFeatures() {
-          ArrayList<Feature<Double>> features = new ArrayList<>();
-          Feature<Double> e = new Feature<>();
-          e.setValue((double) vocabulary.indexOf(new String(finalInputWord)));
-          features.add(e);
-          return features;
-        }
-      });
-    }
   }

   EncodedTrainingSet trainingSet = new EncodedTrainingSet(samples, vocabulary, window);

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@labs.apache.org
For additional commands, e-mail: commits-h...@labs.apache.org