Author: tommaso
Date: Thu Oct 8 15:03:49 2015
New Revision: 1707564

URL: http://svn.apache.org/viewvc?rev=1707564&view=rev
Log:
reduced boilerplate code in ff strategy for applying activation function, added layer specific AFs, improved error derivative calculation
Modified:
    labs/yay/trunk/core/src/main/java/org/apache/yay/core/DefaultDerivativeUpdateFunction.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java
    labs/yay/trunk/core/src/test/java/org/apache/yay/core/NeuralNetworkIntegrationTest.java
    labs/yay/trunk/core/src/test/java/org/apache/yay/core/Word2VecTest.java

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/DefaultDerivativeUpdateFunction.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/DefaultDerivativeUpdateFunction.java?rev=1707564&r1=1707563&r2=1707564&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/DefaultDerivativeUpdateFunction.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/DefaultDerivativeUpdateFunction.java Thu Oct 8 15:03:49 2015
@@ -18,6 +18,8 @@
  */
 package org.apache.yay.core;
 
+import java.util.Arrays;
+
 import org.apache.commons.math3.linear.ArrayRealVector;
 import org.apache.commons.math3.linear.RealMatrix;
 import org.apache.commons.math3.linear.RealVector;
@@ -107,19 +109,13 @@ class DefaultDerivativeUpdateFunction im
   private RealVector calculateOutputError(TrainingExample<Double, Double> trainingExample, RealVector[] activations) {
     RealVector output = activations[activations.length - 1];
 
-//    Double[] sampleOutput = new Double[output.getDimension()];
     Double[] actualOutput = trainingExample.getOutput();
-//    int sampleOutputIntValue = actualOutput.intValue();
-//    if (sampleOutputIntValue < sampleOutput.length) {
-//      sampleOutput[sampleOutputIntValue] = 1d;
-//    } else if (sampleOutput.length == 1) {
-//      sampleOutput[0] = actualOutput;
-//    } else {
-//      throw new RuntimeException("problem with multiclass output mapping");
-//    }
     RealVector learnedOutputRealVector = new ArrayRealVector(actualOutput); // turn example output to a vector
-    // TODO : improve error calculation -> this could be er_a = out_a * (1 - out_a) * (tgt_a - out_a)
-    return output.subtract(learnedOutputRealVector);
+    double[] ones = new double[output.getDimension()];
+    Arrays.fill(ones, 1d);
+
+    // error calculation -> er_a = out_a * (1 - out_a) * (tgt_a - out_a) (was: output.subtract(learnedOutputRealVector))
+    return output.ebeMultiply(new ArrayRealVector(ones).subtract(output)).ebeMultiply(output.subtract(learnedOutputRealVector));
   }
 
 }
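The new return line computes the commented formula element-wise: the sigmoid output delta er_a = out_a * (1 - out_a) * (tgt_a - out_a). Note that the committed last factor is output.subtract(learnedOutputRealVector), i.e. (out - tgt) rather than the quoted (tgt_a - out_a); the two differ only by sign, which the direction of the weight update has to absorb. A minimal standalone sketch of the same computation (illustrative only, class and method names are not from the commit), using mapMultiply/mapAdd instead of the explicit ones array:

import org.apache.commons.math3.linear.ArrayRealVector;
import org.apache.commons.math3.linear.RealVector;

public class OutputErrorSketch {

  // element-wise out * (1 - out) * (out - target), as in the diff above
  static RealVector outputError(RealVector out, RealVector target) {
    RealVector oneMinusOut = out.mapMultiply(-1d).mapAdd(1d); // (1 - out) without allocating a ones array
    return out.ebeMultiply(oneMinusOut).ebeMultiply(out.subtract(target));
  }

  public static void main(String[] args) {
    RealVector out = new ArrayRealVector(new double[]{0.8, 0.2});
    RealVector target = new ArrayRealVector(new double[]{1d, 0d});
    // prints roughly {-0.032; 0.032}: 0.8*0.2*(0.8-1) = -0.032 and 0.2*0.8*(0.2-0) = 0.032
    System.out.println(outputError(out, target));
  }
}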
Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java?rev=1707564&r1=1707563&r2=1707564&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java Thu Oct 8 15:03:49 2015
@@ -20,10 +20,14 @@ package org.apache.yay.core;
 import java.util.ArrayList;
 import java.util.Collection;
+import java.util.HashMap;
+import java.util.Map;
+
 import org.apache.commons.collections.CollectionUtils;
 import org.apache.commons.collections.Transformer;
 import org.apache.commons.math3.linear.ArrayRealVector;
 import org.apache.commons.math3.linear.RealMatrix;
+import org.apache.commons.math3.linear.RealMatrixChangingVisitor;
 import org.apache.commons.math3.linear.RealVector;
 import org.apache.yay.ActivationFunction;
 import org.apache.yay.PredictionStrategy;
@@ -40,10 +44,15 @@ import org.apache.yay.core.utils.Convers
  */
 public class FeedForwardStrategy implements PredictionStrategy<Double, Double> {
 
-  private final ActivationFunction<Double> activationFunction;
+  private final Map<Integer, ActivationFunction<Double>> activationFunctionMap;
 
   public FeedForwardStrategy(ActivationFunction<Double> activationFunction) {
-    this.activationFunction = activationFunction;
+    this.activationFunctionMap = new HashMap<Integer, ActivationFunction<Double>>();
+    this.activationFunctionMap.put(0, activationFunction);
+  }
+
+  public FeedForwardStrategy(Map<Integer, ActivationFunction<Double>> activationFunctionMap) {
+    this.activationFunctionMap = activationFunctionMap;
   }
 
   @Override
@@ -69,32 +78,27 @@ public class FeedForwardStrategy impleme
       x = x.multiply(currentWeightsMatrix.transpose());
 
       // apply the activation function to each element in the matrix
-      for (int i = 0; i < x.getRowDimension(); i++) {
-        double[] doubles = x.getRow(i);
-        final ArrayList<Double> row = new ArrayList<Double>(doubles.length);
-        for (int j = 0; j < doubles.length; j++) {
-          row.add(j, doubles[j]);
+      int idx = activationFunctionMap.size() == realMatrixSet.length ? w : 0;
+      final ActivationFunction<Double> af = activationFunctionMap.get(idx);
+      x.walkInRowOrder(new RealMatrixChangingVisitor() {
+        @Override
+        public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) {
+
+        }
+
+        @Override
+        public double visit(int row, int column, double value) {
+          return af.apply(value);
         }
-        // TODO : see if bias term is handled correctly here
-        CollectionUtils.transform(row, new ActivationRowTransformer());
-        double[] finRow = new double[row.size()];
-        for (int h = 0; h < finRow.length; h++) {
-          finRow[h] = row.get(h);
+
+        @Override
+        public double end() {
+          return 0;
         }
-        x.setRow(i, finRow);
-      }
+      });
 
       debugOutput[w] = x.getRowVector(0);
     }
 
     return debugOutput;
   }
 
-  private class ActivationRowTransformer implements Transformer {
-    @Override
-    public Object transform(Object input) {
-      assert input instanceof Double;
-      final Double d = (Double) input;
-      return activationFunction.apply(d);
-    }
-  }
-
 }
\ No newline at end of file
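The row-by-row copy through ArrayList and CollectionUtils.transform is replaced above by commons-math's matrix visitor: walkInRowOrder feeds every entry to visit(), and the returned value overwrites the entry in place. The activation function is looked up per layer (idx = w) only when the map holds exactly one function per weight matrix; otherwise the function registered at key 0 is used for all layers. A self-contained sketch of the visitor idiom (illustrative only, not from the commit):

import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealMatrixChangingVisitor;

public class VisitorSketch {

  public static void main(String[] args) {
    RealMatrix x = new Array2DRowRealMatrix(new double[][]{{-1d, 0d}, {2d, 3d}});
    // apply a sigmoid to every entry in place, no per-row copies needed
    x.walkInRowOrder(new RealMatrixChangingVisitor() {
      @Override
      public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) {
        // nothing to set up
      }

      @Override
      public double visit(int row, int column, double value) {
        return 1d / (1d + Math.exp(-value)); // the return value replaces the entry
      }

      @Override
      public double end() {
        return 0; // walkInRowOrder returns this value; unused here
      }
    });
    System.out.println(x);
  }
}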
Modified: labs/yay/trunk/core/src/test/java/org/apache/yay/core/NeuralNetworkIntegrationTest.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/core/NeuralNetworkIntegrationTest.java?rev=1707564&r1=1707563&r2=1707564&view=diff
==============================================================================
--- labs/yay/trunk/core/src/test/java/org/apache/yay/core/NeuralNetworkIntegrationTest.java (original)
+++ labs/yay/trunk/core/src/test/java/org/apache/yay/core/NeuralNetworkIntegrationTest.java Thu Oct 8 15:03:49 2015
@@ -18,24 +18,21 @@
  */
 package org.apache.yay.core;
 
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.Collection;
-import java.util.Random;
-
 import org.apache.commons.math3.linear.Array2DRowRealMatrix;
 import org.apache.commons.math3.linear.RealMatrix;
-import org.apache.yay.CreationException;
-import org.apache.yay.Feature;
-import org.apache.yay.Input;
-import org.apache.yay.LearningStrategy;
-import org.apache.yay.NeuralNetwork;
-import org.apache.yay.TrainingExample;
-import org.apache.yay.TrainingSet;
+import org.apache.commons.math3.ml.distance.CanberraDistance;
+import org.apache.commons.math3.ml.distance.DistanceMeasure;
+import org.apache.commons.math3.ml.distance.EuclideanDistance;
+import org.apache.yay.*;
 import org.apache.yay.core.utils.ExamplesFactory;
 import org.junit.Test;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Random;
+
 import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
 
 /**
  * Integration test for NN
  */
@@ -134,15 +131,20 @@ public class NeuralNetworkIntegrationTes
     int noOfFeatures = randomWeights[0].getColumnDimension() - 1;
     Collection<TrainingExample<Double, Double>> samples = createSamples(1000000, noOfFeatures, noOfOutputs);
     nn.learn(new TrainingSet<Double, Double>(samples));
+    DistanceMeasure distanceMeasure = new CanberraDistance();
 
     for (TrainingExample<Double, Double> sample : samples) {
      Double[] predictedOutput = nn.predict(sample);
+      double[] a1 = new double[predictedOutput.length];
+      for (int i = 0; i < a1.length; i++) {
+        a1[i] = predictedOutput[i];
+      }
       Double[] expectedOutput = sample.getOutput();
-      boolean equals = Arrays.equals(expectedOutput, predictedOutput);
-//      if (!equals) {
-//        System.err.println(Arrays.toString(expectedOutput) + " vs " + Arrays.toString(predictedOutput));
-//      } else {
-//        System.err.println("equals!");
-//      }
+      double[] a2 = new double[expectedOutput.length];
+      for (int i = 0; i < a2.length; i++) {
+        a2[i] = expectedOutput[i];
+      }
+      double dist = distanceMeasure.compute(a1, a2);
+      assertTrue("expected and actual outputs are distant " + dist, dist < 10d);
     }
   }
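Exact comparison via Arrays.equals could not pass here: a trained network only approximates its targets, so the test now asserts a bounded Canberra distance instead. Each Canberra term is |a_i - b_i| / (|a_i| + |b_i|), so every component contributes at most 1 and the total is bounded by the output dimension. A small standalone sketch (illustrative names, not from the commit):

import org.apache.commons.math3.ml.distance.CanberraDistance;
import org.apache.commons.math3.ml.distance.DistanceMeasure;

public class CanberraSketch {

  public static void main(String[] args) {
    DistanceMeasure d = new CanberraDistance();
    double[] predicted = {0.9, 0.1};
    double[] expected = {1.0, 0.0};
    // |0.9-1.0|/(0.9+1.0) + |0.1-0.0|/(0.1+0.0) = 0.0526... + 1.0 ~= 1.0526
    System.out.println(d.compute(predicted, expected));
  }
}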
Modified: labs/yay/trunk/core/src/test/java/org/apache/yay/core/Word2VecTest.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/core/Word2VecTest.java?rev=1707564&r1=1707563&r2=1707564&view=diff
==============================================================================
--- labs/yay/trunk/core/src/test/java/org/apache/yay/core/Word2VecTest.java (original)
+++ labs/yay/trunk/core/src/test/java/org/apache/yay/core/Word2VecTest.java Thu Oct 8 15:03:49 2015
@@ -31,21 +31,21 @@ import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collection;
 import java.util.Collections;
+import java.util.HashMap;
 import java.util.LinkedList;
 import java.util.List;
+import java.util.Map;
 import java.util.Random;
 
 import org.apache.commons.math3.linear.Array2DRowRealMatrix;
 import org.apache.commons.math3.linear.MatrixUtils;
 import org.apache.commons.math3.linear.RealMatrix;
-import org.apache.commons.math3.linear.SingularValueDecomposition;
+import org.apache.yay.ActivationFunction;
 import org.apache.yay.Feature;
 import org.apache.yay.Input;
 import org.apache.yay.NeuralNetwork;
 import org.apache.yay.TrainingExample;
 import org.apache.yay.TrainingSet;
-import org.apache.yay.core.utils.ConversionUtils;
-import org.apache.yay.core.utils.ExamplesFactory;
 import org.junit.Test;
 
 import static org.junit.Assert.*;
@@ -70,13 +70,17 @@ public class Word2VecTest {
     TrainingExample<Double, Double> next = trainingSet.iterator().next();
     int inputSize = next.getFeatures().size();
     int outputSize = next.getOutput().length;
-    int n = new Random().nextInt(20);
+    int n = new Random().nextInt(20) + 5;
     RealMatrix[] randomWeights = createRandomWeights(inputSize, n, outputSize);
-    FeedForwardStrategy predictionStrategy = new FeedForwardStrategy(new IdentityActivationFunction<Double>());
-    BackPropagationLearningStrategy learningStrategy = new BackPropagationLearningStrategy(0.0005d, -1,
+    Map<Integer, ActivationFunction<Double>> activationFunctions = new HashMap<Integer, ActivationFunction<Double>>();
+    activationFunctions.put(0, new IdentityActivationFunction<Double>());
+    // TODO : place a softmax activation for the output layer
+    activationFunctions.put(1, new IdentityActivationFunction<Double>());
+    FeedForwardStrategy predictionStrategy = new FeedForwardStrategy(activationFunctions);
+    BackPropagationLearningStrategy learningStrategy = new BackPropagationLearningStrategy(0.05d, 10,
         BackPropagationLearningStrategy.DEFAULT_THRESHOLD, predictionStrategy, new LMSCostFunction(),
-        30);
+        80);
     NeuralNetwork neuralNetwork = NeuralNetworkFactory.create(randomWeights, learningStrategy, predictionStrategy);
     neuralNetwork.learn(trainingSet);
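One wiring note on the test above: FeedForwardStrategy switches to per-layer lookup only when the map holds exactly one function per weight matrix, so with two weight matrices both keys 0 and 1 must be registered (a duplicate key 0 would collapse the map to a single entry). The softmax left as a TODO does not fit the per-element apply(value) contract used by visit(), since it needs the whole layer output for its normalizing denominator. A hypothetical wiring sketch (packages assumed from the file paths above; not part of the commit):

import java.util.HashMap;
import java.util.Map;

import org.apache.yay.ActivationFunction;
import org.apache.yay.core.FeedForwardStrategy;
import org.apache.yay.core.IdentityActivationFunction;

public class PerLayerWiringSketch {

  public static void main(String[] args) {
    Map<Integer, ActivationFunction<Double>> afs = new HashMap<Integer, ActivationFunction<Double>>();
    afs.put(0, new IdentityActivationFunction<Double>()); // applied after the first weight matrix
    afs.put(1, new IdentityActivationFunction<Double>()); // applied after the second (output) matrix
    FeedForwardStrategy strategy = new FeedForwardStrategy(afs); // same wiring as the test above
  }
}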