Author: tommaso
Date: Sun Oct 4 06:21:16 2015
New Revision: 1706651

URL: http://svn.apache.org/viewvc?rev=1706651&view=rev
Log:
enabling multiple NN output neurons
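In short: Hypothesis.predict(Input) now returns O[] rather than a single O, TrainingExample.getOutput() now returns O[], and NeuralNetworkFactory.create() drops the SelectionFunction argument that used to collapse the output vector into one value, so a network can expose one prediction per output neuron. A minimal caller sketch against the revised API (the weight values are illustrative, and the package locations of FeedForwardStrategy, SigmoidFunction, VoidLearningStrategy and ExamplesFactory are assumed from their usage in the tests below):

    import org.apache.commons.math3.linear.Array2DRowRealMatrix;
    import org.apache.commons.math3.linear.RealMatrix;
    import org.apache.yay.NeuralNetwork;
    import org.apache.yay.TrainingExample;
    import org.apache.yay.core.FeedForwardStrategy;
    import org.apache.yay.core.NeuralNetworkFactory;
    import org.apache.yay.core.SigmoidFunction;
    import org.apache.yay.core.VoidLearningStrategy;
    import org.apache.yay.core.utils.ExamplesFactory;

    public class MultiOutputSketch {
      public static void main(String[] args) throws Exception {
        // one weight layer with two output neurons; the first column is the bias weight:
        // row 0 approximates a logical AND, row 1 a logical OR (cf. the gate tests below)
        RealMatrix[] weights = {new Array2DRowRealMatrix(new double[][]{
            {-30d, 20d, 20d},
            {-10d, 20d, 20d}})};
        NeuralNetwork nn = NeuralNetworkFactory.create(weights,
            new VoidLearningStrategy<Double, Double>(),
            new FeedForwardStrategy(new SigmoidFunction()));
        // a training example now carries a whole target vector, one entry per output neuron
        TrainingExample<Double, Double> sample =
            ExamplesFactory.createDoubleArrayTrainingExample(new Double[]{1d, 1d}, 1d, 1d);
        Double[] out = nn.predict(sample); // previously a single Double chosen by a MaxSelectionFunction
        System.out.println(out[0] + " " + out[1]); // both close to 1.0 for input {1, 1}
      }
    }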
Added:
    labs/yay/trunk/core/src/test/java/org/apache/yay/core/NeuralNetworkIntegrationTest.java
      - copied, changed from r1705721, labs/yay/trunk/core/src/test/java/org/apache/yay/core/NeuralNetworkFactoryTest.java
Removed:
    labs/yay/trunk/core/src/test/java/org/apache/yay/core/NeuralNetworkFactoryTest.java
Modified:
    labs/yay/trunk/api/src/main/java/org/apache/yay/Hypothesis.java
    labs/yay/trunk/api/src/main/java/org/apache/yay/NeuralNetwork.java
    labs/yay/trunk/api/src/main/java/org/apache/yay/TrainingExample.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/core/BasicPerceptron.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/core/DefaultDerivativeUpdateFunction.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/core/LogisticRegressionCostFunction.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/core/NeuralNetworkFactory.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/core/utils/ExamplesFactory.java
    labs/yay/trunk/core/src/test/java/org/apache/yay/core/BackPropagationLearningStrategyTest.java
    labs/yay/trunk/core/src/test/java/org/apache/yay/core/BasicPerceptronTest.java
    labs/yay/trunk/core/src/test/java/org/apache/yay/core/LogisticRegressionCostFunctionTest.java

Modified: labs/yay/trunk/api/src/main/java/org/apache/yay/Hypothesis.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/api/src/main/java/org/apache/yay/Hypothesis.java?rev=1706651&r1=1706650&r2=1706651&view=diff
==============================================================================
--- labs/yay/trunk/api/src/main/java/org/apache/yay/Hypothesis.java (original)
+++ labs/yay/trunk/api/src/main/java/org/apache/yay/Hypothesis.java Sun Oct 4 06:21:16 2015
@@ -18,6 +18,8 @@
  */
 package org.apache.yay;
 
+import java.util.List;
+
 /**
  * In machine learning an hypothesis is the output of applying a learning
  * algorithm to a training set, an hypothesis maps new inputs to possible outputs.
@@ -45,7 +47,7 @@ public interface Hypothesis<T, I, O> {
    * @return the predicted output
    * @throws PredictionException if any error occurs during the prediction phase
    */
-  O predict(Input<I> input) throws PredictionException;
+  O[] predict(Input<I> input) throws PredictionException;
 
   /**
    * Let this <code>Hypothesis</code> learn by experience, in the form of a {@link org.apache.yay.TrainingSet}

Modified: labs/yay/trunk/api/src/main/java/org/apache/yay/NeuralNetwork.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/api/src/main/java/org/apache/yay/NeuralNetwork.java?rev=1706651&r1=1706650&r2=1706651&view=diff
==============================================================================
--- labs/yay/trunk/api/src/main/java/org/apache/yay/NeuralNetwork.java (original)
+++ labs/yay/trunk/api/src/main/java/org/apache/yay/NeuralNetwork.java Sun Oct 4 06:21:16 2015
@@ -27,13 +27,4 @@ import org.apache.commons.math3.linear.R
  */
 public interface NeuralNetwork extends Hypothesis<RealMatrix, Double, Double> {
 
-  /**
-   * Predict the output for a given input
-   *
-   * @param input the input to evaluate
-   * @return the predicted output
-   * @throws PredictionException if any error occurs during the prediction phase
-   */
-  Double[] getOutputVector(Input<Double> input) throws PredictionException;
-
 }

Modified: labs/yay/trunk/api/src/main/java/org/apache/yay/TrainingExample.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/api/src/main/java/org/apache/yay/TrainingExample.java?rev=1706651&r1=1706650&r2=1706651&view=diff
==============================================================================
--- labs/yay/trunk/api/src/main/java/org/apache/yay/TrainingExample.java (original)
+++ labs/yay/trunk/api/src/main/java/org/apache/yay/TrainingExample.java Sun Oct 4 06:21:16 2015
@@ -28,6 +28,6 @@ public interface TrainingExample<F, O> e
    *
    * @return the output
    */
-  O getOutput();
+  O[] getOutput();
 
 }

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java?rev=1706651&r1=1706650&r2=1706651&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java Sun Oct 4 06:21:16 2015
@@ -19,13 +19,10 @@ package org.apache.yay.core;
 
 import java.util.Arrays;
-import java.util.DoubleSummaryStatistics;
 import java.util.Iterator;
 
 import org.apache.commons.math3.linear.Array2DRowRealMatrix;
-import org.apache.commons.math3.linear.ArrayRealVector;
 import org.apache.commons.math3.linear.RealMatrix;
-import org.apache.commons.math3.linear.RealVector;
 import org.apache.yay.CostFunction;
 import org.apache.yay.DerivativeUpdateFunction;
 import org.apache.yay.LearningStrategy;
@@ -34,7 +31,6 @@ import org.apache.yay.PredictionStrategy
 import org.apache.yay.TrainingExample;
 import org.apache.yay.TrainingSet;
 import org.apache.yay.WeightLearningException;
-import org.apache.yay.core.utils.ConversionUtils;
 
 /**
  * Back propagation learning algorithm for neural networks implementation (see
@@ -71,7 +67,7 @@ public class BackPropagationLearningStra
 
   public BackPropagationLearningStrategy() {
     // commonly used defaults
-    this.predictionStrategy = new FeedForwardStrategy(new SigmoidFunction());
+    this.predictionStrategy = new FeedForwardStrategy(new TanhFunction());
    this.costFunction = new LogisticRegressionCostFunction();
    this.alpha = DEFAULT_ALPHA;
    this.threshold = DEFAULT_THRESHOLD;
@@ -85,7 +81,7 @@ public class BackPropagationLearningStra
     try {
       int iterations = 0;
 
-      NeuralNetwork hypothesis = NeuralNetworkFactory.create(weightsMatrixSet, new VoidLearningStrategy<Double, Double>(), predictionStrategy, new MaxSelectionFunction<Double>());
+      NeuralNetwork hypothesis = NeuralNetworkFactory.create(weightsMatrixSet, new VoidLearningStrategy<Double, Double>(), predictionStrategy);
 
       Iterator<TrainingExample<Double, Double>> iterator = trainingExamples.iterator();
       double cost = Double.MAX_VALUE;
@@ -142,7 +138,10 @@ public class BackPropagationLearningStra
       double[][] updatedWeights = weightsMatrixSet[l].getData();
       for (int i = 0; i < updatedWeights.length; i++) {
         for (int j = 0; j < updatedWeights[i].length; j++) {
-          updatedWeights[i][j] = updatedWeights[i][j] - alpha * derivatives[l].getData()[i][j];
+          double curVal = updatedWeights[i][j];
+          if (curVal > 0d && curVal < 1d) {
+            updatedWeights[i][j] = updatedWeights[i][j] - alpha * derivatives[l].getData()[i][j];
+          }
         }
       }
       updatedParameters[l] = new Array2DRowRealMatrix(updatedWeights);

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/BasicPerceptron.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/BasicPerceptron.java?rev=1706651&r1=1706650&r2=1706651&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/BasicPerceptron.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/BasicPerceptron.java Sun Oct 4 06:21:16 2015
@@ -62,7 +62,7 @@ public class BasicPerceptron implements
       Collection<Double> doubles = ConversionUtils.toValuesCollection(example.getFeatures());
       Double[] inputs = doubles.toArray(new Double[doubles.size()]);
       Double calculatedOutput = perceptronNeuron.elaborate(inputs);
-      int diff = calculatedOutput.compareTo(example.getOutput());
+      int diff = calculatedOutput.compareTo(example.getOutput()[0]);
       if (diff > 0) {
         for (int i = 0; i < currentWeights.length; i++) {
           currentWeights[i] += inputs[i];
@@ -90,17 +90,10 @@
   }
 
   @Override
-  public Double predict(Input<Double> input) throws PredictionException {
-    return perceptronNeuron.elaborate(ConversionUtils.toValuesCollection(input.getFeatures()).toArray(
+  public Double[] predict(Input<Double> input) throws PredictionException {
+    Double output = perceptronNeuron.elaborate(ConversionUtils.toValuesCollection(input.getFeatures()).toArray(
         new Double[input.getFeatures().size()]));
+    return new Double[]{output};
   }
 
-  @Override
-  public Double[] getOutputVector(Input<Double> input) throws PredictionException {
-    Double elaborate = perceptronNeuron.elaborate(ConversionUtils.toValuesCollection(input.getFeatures()).toArray(
-        new Double[input.getFeatures().size()]));
-    Double[] ar = new Double[1];
-    ar[0] = elaborate;
-    return ar;
-  }
 }

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/DefaultDerivativeUpdateFunction.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/DefaultDerivativeUpdateFunction.java?rev=1706651&r1=1706650&r2=1706651&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/DefaultDerivativeUpdateFunction.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/DefaultDerivativeUpdateFunction.java Sun Oct 4 06:21:16 2015
@@ -78,10 +78,6 @@ public class DefaultDerivativeUpdateFunc
       count++;
     }
 
-    return createDerivatives(triangle, count);
-  }
-
-  private RealMatrix[] createDerivatives(RealMatrix[] triangle, double count) {
     RealMatrix[] derivatives = new RealMatrix[triangle.length];
     for (int i = 0; i < triangle.length; i++) {
       // TODO : introduce regularization diversification on bias term (currently not regularized)
@@ -111,16 +107,17 @@
 
   private RealVector calculateOutputError(TrainingExample<Double, Double> trainingExample,
                                           RealVector[] activations) {
     RealVector output = activations[activations.length - 1];
-    Double[] sampleOutput = new Double[output.getDimension()];
-    int sampleOutputIntValue = trainingExample.getOutput().intValue();
-    if (sampleOutputIntValue < sampleOutput.length) {
-      sampleOutput[sampleOutputIntValue] = 1d;
-    } else if (sampleOutput.length == 1) {
-      sampleOutput[0] = trainingExample.getOutput();
-    } else {
-      throw new RuntimeException("problem with multiclass output mapping");
-    }
-    RealVector learnedOutputRealVector = new ArrayRealVector(sampleOutput); // turn example output to a vector
+//    Double[] sampleOutput = new Double[output.getDimension()];
+    Double[] actualOutput = trainingExample.getOutput();
+//    int sampleOutputIntValue = actualOutput.intValue();
+//    if (sampleOutputIntValue < sampleOutput.length) {
+//      sampleOutput[sampleOutputIntValue] = 1d;
+//    } else if (sampleOutput.length == 1) {
+//      sampleOutput[0] = actualOutput;
+//    } else {
+//      throw new RuntimeException("problem with multiclass output mapping");
+//    }
+    RealVector learnedOutputRealVector = new ArrayRealVector(actualOutput); // turn example output to a vector
     // TODO : improve error calculation -> this could be er_a = out_a * (1 - out_a) * (tgt_a - out_a)
     return output.subtract(learnedOutputRealVector);

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/LogisticRegressionCostFunction.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/LogisticRegressionCostFunction.java?rev=1706651&r1=1706650&r2=1706651&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/LogisticRegressionCostFunction.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/LogisticRegressionCostFunction.java Sun Oct 4 06:21:16 2015
@@ -74,11 +74,12 @@ public class LogisticRegressionCostFunct
 
     Double res = 0d;
     for (TrainingExample<Double, Double> input : trainingExamples) {
-      // TODO : handle this for multiple outputs (multi class classification)
-      Double predictedOutput = hypothesis.predict(input);
-      Double sampleOutput = input.getOutput();
-      res += sampleOutput * Math.log(predictedOutput) + (1d - sampleOutput)
-              * Math.log(1d - predictedOutput);
+      Double[] predictedOutput = hypothesis.predict(input);
+      Double[] sampleOutput = input.getOutput();
+      for (int i = 0; i < predictedOutput.length; i++) {
+        res += sampleOutput[i] * Math.log(predictedOutput[i]) + (1d - sampleOutput[i])
+                * Math.log(1d - predictedOutput[i]);
+      }
     }
     return (-1d / trainingExamples.length) * res;
   }

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/NeuralNetworkFactory.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/NeuralNetworkFactory.java?rev=1706651&r1=1706650&r2=1706651&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/NeuralNetworkFactory.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/NeuralNetworkFactory.java Sun Oct 4 06:21:16 2015
@@ -18,11 +18,8 @@
  */
 package org.apache.yay.core;
 
-import java.util.ArrayList;
-import java.util.Arrays;
 import java.util.Collection;
 import org.apache.commons.math3.linear.RealMatrix;
-import org.apache.commons.math3.linear.RealVector;
 import org.apache.yay.CreationException;
 import org.apache.yay.Input;
 import org.apache.yay.LearningException;
@@ -30,7 +27,6 @@ import org.apache.yay.LearningStrategy;
 import org.apache.yay.NeuralNetwork;
 import org.apache.yay.PredictionException;
 import org.apache.yay.PredictionStrategy;
-import org.apache.yay.SelectionFunction;
 import org.apache.yay.TrainingSet;
 import org.apache.yay.WeightLearningException;
 import org.apache.yay.core.utils.ConversionUtils;
@@ -51,12 +47,10 @@ public class NeuralNetworkFactory {
    * @throws org.apache.yay.CreationException
    */
   public static NeuralNetwork create(final RealMatrix[] realMatrixSet, final LearningStrategy<Double, Double> learningStrategy,
-                                     final PredictionStrategy<Double, Double> predictionStrategy,
-                                     final SelectionFunction<Collection<Double>, Double> selectionFunction) throws CreationException {
+                                     final PredictionStrategy<Double, Double> predictionStrategy) throws CreationException {
     return new NeuralNetwork() {
 
-      @Override
-      public Double[] getOutputVector(Input<Double> input) throws PredictionException {
+      private Double[] getOutputVector(Input<Double> input) throws PredictionException {
         Collection<Double> inputVector = ConversionUtils.toValuesCollection(input.getFeatures());
         return predictionStrategy.predictOutput(inputVector, updatedRealMatrixSet);
       }
@@ -83,10 +77,9 @@ public class NeuralNetworkFactory {
       }
 
       @Override
-      public Double predict(Input<Double> input) throws PredictionException {
+      public Double[] predict(Input<Double> input) throws PredictionException {
         try {
-          Double[] doubles = getOutputVector(input);
-          return selectionFunction.selectOutput(Arrays.asList(doubles));
+          return getOutputVector(input);
         } catch (Exception e) {
           throw new PredictionException(e);
         }

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/utils/ExamplesFactory.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/utils/ExamplesFactory.java?rev=1706651&r1=1706650&r2=1706651&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/utils/ExamplesFactory.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/utils/ExamplesFactory.java Sun Oct 4 06:21:16 2015
@@ -39,22 +39,22 @@ public class ExamplesFactory {
       }
 
       @Override
-      public Double getOutput() {
-        return output;
+      public Double[] getOutput() {
+        return new Double[]{output};
       }
     };
   }
 
-  public static TrainingExample<Double, Collection<Double[]>> createSGMExample(final Collection<Double[]> output,
+  public static TrainingExample<Double, Double> createDoubleArrayTrainingExample(final Double[] output,
                                                                                final Double... featuresValues) {
-    return new TrainingExample<Double, Collection<Double[]>>() {
+    return new TrainingExample<Double, Double>() {
       @Override
       public ArrayList<Feature<Double>> getFeatures() {
         return doublesToFeatureVector(featuresValues);
       }
 
       @Override
-      public Collection<Double[]> getOutput() {
+      public Double[] getOutput() {
         return output;
       }
     };

Modified: labs/yay/trunk/core/src/test/java/org/apache/yay/core/BackPropagationLearningStrategyTest.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/core/BackPropagationLearningStrategyTest.java?rev=1706651&r1=1706650&r2=1706651&view=diff
==============================================================================
--- labs/yay/trunk/core/src/test/java/org/apache/yay/core/BackPropagationLearningStrategyTest.java (original)
+++ labs/yay/trunk/core/src/test/java/org/apache/yay/core/BackPropagationLearningStrategyTest.java Sun Oct 4 06:21:16 2015
@@ -41,9 +41,9 @@ public class BackPropagationLearningStra
   public void testLearningWithRandomNetwork() throws Exception {
     BackPropagationLearningStrategy backPropagationLearningStrategy = new BackPropagationLearningStrategy();
 
-    RealMatrix[] initialWeights = createRandomWeights();
+    RealMatrix[] initialWeights = createRandomWeights(2);
 
-    Collection<TrainingExample<Double, Double>> samples = createSamples(1000, initialWeights[0].getColumnDimension() - 1);
+    Collection<TrainingExample<Double, Double>> samples = createSamples(1000, initialWeights[0].getColumnDimension() - 1, 2);
     TrainingSet<Double, Double> trainingSet = new TrainingSet<Double, Double>(samples);
     RealMatrix[] learntWeights = backPropagationLearningStrategy.learnWeights(initialWeights, trainingSet);
     assertNotNull(learntWeights);
@@ -62,9 +62,9 @@ public class BackPropagationLearningStra
 
     BackPropagationLearningStrategy backPropagationLearningStrategy = new BackPropagationLearningStrategy(alpha, threshold, predictionStrategy, costFunction);
 
-    RealMatrix[] initialWeights = createRandomWeights();
+    RealMatrix[] initialWeights = createRandomWeights(10);
 
-    Collection<TrainingExample<Double, Double>> samples = createSamples(1000, initialWeights[0].getColumnDimension() - 1);
+    Collection<TrainingExample<Double, Double>> samples = createSamples(1000, initialWeights[0].getColumnDimension() - 1, 10);
     TrainingSet<Double, Double> trainingSet = new TrainingSet<Double, Double>(samples);
     RealMatrix[] learntWeights = backPropagationLearningStrategy.learnWeights(initialWeights, trainingSet);
     assertNotNull(learntWeights);
@@ -82,7 +82,7 @@ public class BackPropagationLearningStra
     }
   }
 
-  private RealMatrix[] createRandomWeights() {
+  private RealMatrix[] createRandomWeights(int outputSize) {
     Random r = new Random();
     int weightsCount = (Math.abs(r.nextInt()) % 5) + 2;
 
@@ -95,7 +95,7 @@ public class BackPropagationLearningStra
       } else {
         cols = initialWeights[i - 1].getRowDimension();
         if (i == weightsCount - 1) {
-          rows = 1;
+          rows = outputSize;
         }
       }
       double[][] d = new double[rows][cols];
@@ -137,7 +137,7 @@ public class BackPropagationLearningStra
     initialWeights[1] = new Array2DRowRealMatrix(new double[][]{{0d, 0d, 0d, 0d}, {1d, 0.5d, 1d, 0.5d}, {1d, 0.1d, 8d, 0.1d}, {1d, 0.1d, 8d, 0.2d}}); // 4 x 4
     initialWeights[2] = new Array2DRowRealMatrix(new double[][]{{1d, 2d, 0.3d, 0.5d}}); // 1 x 4
 
-    Collection<TrainingExample<Double, Double>> samples = createSamples(10000, 2);
+    Collection<TrainingExample<Double, Double>> samples = createSamples(10000, 2, 1);
     TrainingSet<Double, Double> trainingSet = new TrainingSet<Double, Double>(samples);
     RealMatrix[] learntWeights = backPropagationLearningStrategy.learnWeights(initialWeights, trainingSet);
     assertNotNull(learntWeights);
@@ -169,7 +169,7 @@ public class BackPropagationLearningStra
     initialWeights[1] = new Array2DRowRealMatrix(new double[][]{{0d, 0d, 0d, 0d}, {1d, 0.5d, 1d, 0.5d}, {1d, 0.1d, 8d, 0.1d}, {1d, 0.1d, 8d, 0.2d}});
     initialWeights[2] = new Array2DRowRealMatrix(new double[][]{{1d, 2d, 0.3d, 0.5d}});
 
-    Collection<TrainingExample<Double, Double>> samples = createSamples(10000, 2);
+    Collection<TrainingExample<Double, Double>> samples = createSamples(10000, 2, 1);
     TrainingSet<Double, Double> trainingSet = new TrainingSet<Double, Double>(samples);
     RealMatrix[] learntWeights = backPropagationLearningStrategy.learnWeights(initialWeights, trainingSet);
     assertNotNull(learntWeights);
@@ -201,7 +201,7 @@ public class BackPropagationLearningStra
         {1d, Math.random(), Math.random(), Math.random()}
     });
 
-    Collection<TrainingExample<Double, Double>> samples = createSamples(1000000, 2);
+    Collection<TrainingExample<Double, Double>> samples = createSamples(1000000, 2, 1);
     TrainingSet<Double, Double> trainingSet = new TrainingSet<Double, Double>(samples);
     RealMatrix[] learntWeights = backPropagationLearningStrategy.learnWeights(initialWeights, trainingSet);
     assertNotNull(learntWeights);
@@ -211,14 +211,18 @@ public class BackPropagationLearningStra
     assertFalse(learntWeights[2].equals(initialWeights[2]));
   }
 
-  private Collection<TrainingExample<Double, Double>> createSamples(int size, int noOfFeatures) {
+  private Collection<TrainingExample<Double, Double>> createSamples(int size, int noOfFeatures, int noOfOutputs) {
     Collection<TrainingExample<Double, Double>> trainingExamples = new ArrayList<TrainingExample<Double, Double>>(size);
     for (int i = 0; i < size; i++) {
       Double[] featureValues = new Double[noOfFeatures];
       for (int j = 0; j < noOfFeatures; j++) {
         featureValues[j] = Math.random();
       }
-      trainingExamples.add(ExamplesFactory.createDoubleTrainingExample(1d, featureValues));
+      Double[] outputs = new Double[noOfOutputs];
+      for (int j = 0; j < outputs.length; j++) {
+        outputs[j] = Math.random();
+      }
+      trainingExamples.add(ExamplesFactory.createDoubleArrayTrainingExample(outputs, featureValues));
     }
     return trainingExamples;
   }

Modified: labs/yay/trunk/core/src/test/java/org/apache/yay/core/BasicPerceptronTest.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/core/BasicPerceptronTest.java?rev=1706651&r1=1706650&r2=1706651&view=diff
==============================================================================
--- labs/yay/trunk/core/src/test/java/org/apache/yay/core/BasicPerceptronTest.java (original)
+++ labs/yay/trunk/core/src/test/java/org/apache/yay/core/BasicPerceptronTest.java Sun Oct 4 06:21:16 2015
@@ -86,7 +86,7 @@ public class BasicPerceptronTest {
         r.nextDouble(), r.nextDouble(), r.nextDouble(), r.nextDouble(), r.nextDouble(), r.nextDouble());
     basicPerceptron.learn(bigDataset);
 
-    Double output = basicPerceptron.predict(createInput(r));
+    Double output = basicPerceptron.predict(createInput(r))[0];
     assertTrue(output == 0d || output == 1d);
   }
 
@@ -102,7 +102,7 @@
         r.nextDouble(), r.nextDouble());
     basicPerceptron.learn(bigDataset);
     TrainingExample<Double, Double> input = createInput(r);
-    Double output = basicPerceptron.predict(input);
+    Double output = basicPerceptron.predict(input)[0];
     assertTrue(output == 0d || output == 1d);
     basicPerceptron.learn(createTrainingExample(1d, r.nextDouble(), r.nextDouble(),
         r.nextDouble(), r.nextDouble(), r.nextDouble(),
@@ -110,7 +110,7 @@
         r.nextDouble(), r.nextDouble(), r.nextDouble(), r.nextDouble(),
         r.nextDouble(), r.nextDouble(), r.nextDouble(), r.nextDouble(),
         r.nextDouble(), r.nextDouble()));
-    Double secondOutput = basicPerceptron.predict(input);
+    Double secondOutput = basicPerceptron.predict(input)[0];
     assertTrue(secondOutput == 0d || secondOutput == 1d);
   }
 
@@ -135,7 +135,7 @@
   public void testLearnAndPredictWithSmallDataset() throws Exception {
     BasicPerceptron basicPerceptron = new BasicPerceptron(1d, 2d, 3d, 4d);
     basicPerceptron.learn(smallDataset);
-    Double output = basicPerceptron.predict(createTrainingExample(null, 1d, 6d, 0.4d));
+    Double output = basicPerceptron.predict(createTrainingExample(null, 1d, 6d, 0.4d))[0];
     assertEquals(Double.valueOf(1d), output);
   }
 
@@ -157,8 +157,8 @@
       }
 
      @Override
-      public Double getOutput() {
-        return output;
+      public Double[] getOutput() {
+        return new Double[]{output};
      }
    };
  }

Modified: labs/yay/trunk/core/src/test/java/org/apache/yay/core/LogisticRegressionCostFunctionTest.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/core/LogisticRegressionCostFunctionTest.java?rev=1706651&r1=1706650&r2=1706651&view=diff
==============================================================================
--- labs/yay/trunk/core/src/test/java/org/apache/yay/core/LogisticRegressionCostFunctionTest.java (original)
+++ labs/yay/trunk/core/src/test/java/org/apache/yay/core/LogisticRegressionCostFunctionTest.java Sun Oct 4 06:21:16 2015
@@ -64,8 +64,7 @@ public class LogisticRegressionCostFunct
     final RealMatrix[] orWeightsMatrixSet = new RealMatrix[]{singleOrLayerWeights};
 
     final NeuralNetwork neuralNetwork = NeuralNetworkFactory.create(orWeightsMatrixSet,
-            new VoidLearningStrategy<Double, Double>(), new FeedForwardStrategy(
-            new SigmoidFunction()), new MaxSelectionFunction<Double>());
+            new VoidLearningStrategy<Double, Double>(), new FeedForwardStrategy(new SigmoidFunction()));
 
     Double cost = costFunction.calculateAggregatedCost(trainingSet, neuralNetwork);
     assertTrue("cost should not be negative", cost > 0d);

Copied: labs/yay/trunk/core/src/test/java/org/apache/yay/core/NeuralNetworkIntegrationTest.java (from r1705721, labs/yay/trunk/core/src/test/java/org/apache/yay/core/NeuralNetworkFactoryTest.java)
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/core/NeuralNetworkIntegrationTest.java?p2=labs/yay/trunk/core/src/test/java/org/apache/yay/core/NeuralNetworkIntegrationTest.java&p1=labs/yay/trunk/core/src/test/java/org/apache/yay/core/NeuralNetworkFactoryTest.java&r1=1705721&r2=1706651&rev=1706651&view=diff
==============================================================================
--- labs/yay/trunk/core/src/test/java/org/apache/yay/core/NeuralNetworkFactoryTest.java (original)
+++ labs/yay/trunk/core/src/test/java/org/apache/yay/core/NeuralNetworkIntegrationTest.java Sun Oct 4 06:21:16 2015
@@ -19,31 +19,39 @@
 package org.apache.yay.core;
 
 import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Random;
+
 import org.apache.commons.math3.linear.Array2DRowRealMatrix;
 import org.apache.commons.math3.linear.RealMatrix;
 import org.apache.yay.CreationException;
 import org.apache.yay.Feature;
 import org.apache.yay.Input;
+import org.apache.yay.LearningStrategy;
 import org.apache.yay.NeuralNetwork;
+import org.apache.yay.TrainingExample;
+import org.apache.yay.TrainingSet;
+import org.apache.yay.core.utils.ExamplesFactory;
 import org.junit.Test;
 
 import static org.junit.Assert.assertEquals;
 
 /**
- * Testcase for {@link org.apache.yay.core.NeuralNetworkFactory}
+ * Integration test for NN
  */
-public class NeuralNetworkFactoryTest {
+public class NeuralNetworkIntegrationTest {
 
   @Test
   public void andNNCreationTest() throws Exception {
     double[][] weights = {{-30d, 20d, 20d}};
     RealMatrix singleAndLayerWeights = new Array2DRowRealMatrix(weights);
     RealMatrix[] andRealMatrixSet = new RealMatrix[]{singleAndLayerWeights};
-    NeuralNetwork andNN = createFFNN(andRealMatrixSet);
-    assertEquals(0l, Math.round(andNN.predict(createSample(1d, 0d))));
-    assertEquals(0l, Math.round(andNN.predict(createSample(0d, 1d))));
-    assertEquals(0l, Math.round(andNN.predict(createSample(0d, 0d))));
-    assertEquals(1l, Math.round(andNN.predict(createSample(1d, 1d))));
+    NeuralNetwork andNN = createNN(andRealMatrixSet, new VoidLearningStrategy<Double, Double>());
+    assertEquals(0l, Math.round(andNN.predict(createSample(1d, 0d))[0]));
+    assertEquals(0l, Math.round(andNN.predict(createSample(0d, 1d))[0]));
+    assertEquals(0l, Math.round(andNN.predict(createSample(0d, 0d))[0]));
+    assertEquals(1l, Math.round(andNN.predict(createSample(1d, 1d))[0]));
   }
 
   @Test
@@ -51,11 +59,11 @@
     double[][] weights = {{-10d, 20d, 20d}};
     RealMatrix singleOrLayerWeights = new Array2DRowRealMatrix(weights);
     RealMatrix[] orRealMatrixSet = new RealMatrix[]{singleOrLayerWeights};
-    NeuralNetwork orNN = createFFNN(orRealMatrixSet);
-    assertEquals(1l, Math.round(orNN.predict(createSample(1d, 0d))));
-    assertEquals(1l, Math.round(orNN.predict(createSample(0d, 1d))));
-    assertEquals(0l, Math.round(orNN.predict(createSample(0d, 0d))));
-    assertEquals(1l, Math.round(orNN.predict(createSample(1d, 1d))));
+    NeuralNetwork orNN = createNN(orRealMatrixSet, new VoidLearningStrategy<Double, Double>());
+    assertEquals(1l, Math.round(orNN.predict(createSample(1d, 0d))[0]));
+    assertEquals(1l, Math.round(orNN.predict(createSample(0d, 1d))[0]));
+    assertEquals(0l, Math.round(orNN.predict(createSample(0d, 0d))[0]));
+    assertEquals(1l, Math.round(orNN.predict(createSample(1d, 1d))[0]));
   }
 
   @Test
@@ -63,9 +71,9 @@
     double[][] weights = {{10d, -20d}};
     RealMatrix singleNotLayerWeights = new Array2DRowRealMatrix(weights);
     RealMatrix[] notRealMatrixSet = new RealMatrix[]{singleNotLayerWeights};
-    NeuralNetwork orNN = createFFNN(notRealMatrixSet);
-    assertEquals(1l, Math.round(orNN.predict(createSample(0d))));
-    assertEquals(0l, Math.round(orNN.predict(createSample(1d))));
+    NeuralNetwork orNN = createNN(notRealMatrixSet, new VoidLearningStrategy<Double, Double>());
+    assertEquals(1l, Math.round(orNN.predict(createSample(0d))[0]));
+    assertEquals(0l, Math.round(orNN.predict(createSample(1d))[0]));
   }
 
   @Test
@@ -73,11 +81,11 @@
     RealMatrix firstNorLayerWeights = new Array2DRowRealMatrix(new double[][]{{0, 0, 0}, {-30d, 20d, 20d}, {10d, -20d, -20d}});
     RealMatrix secondNorLayerWeights = new Array2DRowRealMatrix(new double[][]{{-10d, 20d, 20d}});
     RealMatrix[] norRealMatrixSet = new RealMatrix[]{firstNorLayerWeights, secondNorLayerWeights};
-    NeuralNetwork norNN = createFFNN(norRealMatrixSet);
-    assertEquals(0l, Math.round(norNN.predict(createSample(1d, 0d))));
-    assertEquals(0l, Math.round(norNN.predict(createSample(0d, 1d))));
-    assertEquals(1l, Math.round(norNN.predict(createSample(0d, 0d))));
-    assertEquals(1l, Math.round(norNN.predict(createSample(1d, 1d))));
+    NeuralNetwork norNN = createNN(norRealMatrixSet, new VoidLearningStrategy<Double, Double>());
+    assertEquals(0l, Math.round(norNN.predict(createSample(1d, 0d))[0]));
+    assertEquals(0l, Math.round(norNN.predict(createSample(0d, 1d))[0]));
+    assertEquals(1l, Math.round(norNN.predict(createSample(0d, 0d))[0]));
+    assertEquals(1l, Math.round(norNN.predict(createSample(1d, 1d))[0]));
   }
 
   @Test
@@ -87,17 +95,17 @@
     RealMatrix[] RealMatrixes = new RealMatrix[]{firstLayer, secondLayer};
 
-    NeuralNetwork neuralNetwork = createFFNN(RealMatrixes);
+    NeuralNetwork neuralNetwork = createNN(RealMatrixes, new VoidLearningStrategy<Double, Double>());
 
-    Double prdictedValue = neuralNetwork.predict(createSample(5d, 6d, 7d));
+    Double prdictedValue = neuralNetwork.predict(createSample(5d, 6d, 7d))[0];
     assertEquals(1l, Math.round(prdictedValue));
     assertEquals(Double.valueOf(0.9975273768433653d), prdictedValue);
   }
 
-  private NeuralNetwork createFFNN(RealMatrix[] realMatrixes)
+  private NeuralNetwork createNN(RealMatrix[] realMatrixes, LearningStrategy<Double, Double> learningStrategy)
           throws CreationException {
-    return NeuralNetworkFactory.create(realMatrixes, new VoidLearningStrategy<Double, Double>(),
-            new FeedForwardStrategy(new SigmoidFunction()), new MaxSelectionFunction<Double>());
+    return NeuralNetworkFactory.create(realMatrixes, learningStrategy,
+            new FeedForwardStrategy(new SigmoidFunction()));
   }
 
   private Input<Double> createSample(final Double... params) {
@@ -117,4 +125,87 @@
       }
     };
   }
+
+  @Test
+  public void testRandomNNEvaluation() throws Exception {
+    int noOfOutputs = 10;
+    RealMatrix[] randomWeights = createRandomWeights(noOfOutputs);
+    NeuralNetwork nn = createNN(randomWeights, new BackPropagationLearningStrategy());
+    int noOfFeatures = randomWeights[0].getColumnDimension() - 1;
+    Collection<TrainingExample<Double, Double>> samples = createSamples(1000000, noOfFeatures, noOfOutputs);
+    nn.learn(new TrainingSet<Double, Double>(samples));
+    for (TrainingExample<Double, Double> sample : samples) {
+      Double[] predictedOutput = nn.predict(sample);
+      Double[] expectedOutput = sample.getOutput();
+      boolean equals = Arrays.equals(expectedOutput, predictedOutput);
+//      if (!equals) {
+//        System.err.println(Arrays.toString(expectedOutput) + " vs " + Arrays.toString(predictedOutput));
+//      } else {
+//        System.err.println("equals!");
+//      }
+    }
+
+  }
+
+  private Collection<TrainingExample<Double, Double>> createSamples(int size, int noOfFeatures, int noOfOutputs) {
+    Random r = new Random();
+    Collection<TrainingExample<Double, Double>> trainingExamples = new ArrayList<TrainingExample<Double, Double>>(size);
+    for (int i = 0; i < size; i++) {
+      Double[] featureValues = new Double[noOfFeatures];
+      for (int j = 0; j < noOfFeatures; j++) {
+        featureValues[j] = r.nextInt(100) / 101d;
+      }
+      Double[] outputs = new Double[noOfOutputs];
+      for (int j = 0; j < outputs.length; j++) {
+        outputs[j] = r.nextInt(100) / 101d;
+      }
+      trainingExamples.add(ExamplesFactory.createDoubleArrayTrainingExample(outputs, featureValues));
+    }
+    return trainingExamples;
+  }
+
+  private RealMatrix[] createRandomWeights(int outputSize) {
+    Random r = new Random();
+    int weightsCount = (Math.abs(r.nextInt()) % 5) + 2;
+
+    RealMatrix[] initialWeights = new RealMatrix[weightsCount];
+    for (int i = 0; i < weightsCount; i++) {
+      int rows = (Math.abs(r.nextInt()) % 4) + 2;
+      int cols;
+      if (i == 0) {
+        cols = (Math.abs(r.nextInt()) % 4) + 2;
+      } else {
+        cols = initialWeights[i - 1].getRowDimension();
+        if (i == weightsCount - 1) {
+          rows = outputSize;
+        }
+      }
+      double[][] d = new double[rows][cols];
+      for (int c = 0; c < cols; c++) {
+        if (i == weightsCount - 1) {
+          if (c == 0) {
+            d[0][c] = 1d;
+          } else {
+            d[0][c] = r.nextDouble();
+          }
+        } else {
+          d[0][c] = 0;
+        }
+      }
+
+      for (int k = 1; k < rows; k++) {
+        for (int j = 0; j < cols; j++) {
+          double val;
+          if (j == 0) {
+            val = 1d;
+          } else {
+            val = r.nextDouble();
+          }
+          d[k][j] = val;
+        }
+      }
+      initialWeights[i] = new Array2DRowRealMatrix(d);
+    }
+    return initialWeights;
+  }
 }
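One consequence for callers: calculateOutputError() in DefaultDerivativeUpdateFunction no longer expands a scalar class label into a one-hot vector (that mapping is commented out above), so multiclass targets now have to be encoded as vectors before training. A sketch of one way to do that against the new TrainingExample contract, using the factory method introduced in this commit; the oneHot helper is hypothetical, not part of this commit:

    import java.util.Arrays;

    import org.apache.yay.TrainingExample;
    import org.apache.yay.core.utils.ExamplesFactory;

    public class OneHotTargets {

      // hypothetical helper: encode class index 'label' as a one-hot target of length 'numClasses'
      static Double[] oneHot(int label, int numClasses) {
        Double[] target = new Double[numClasses];
        for (int i = 0; i < numClasses; i++) {
          target[i] = (i == label) ? 1d : 0d;
        }
        return target;
      }

      public static void main(String[] args) {
        // class 2 of 4, attached to a two-feature sample via the new factory method
        TrainingExample<Double, Double> example =
            ExamplesFactory.createDoubleArrayTrainingExample(oneHot(2, 4), 0.3d, 0.7d);
        System.out.println(Arrays.toString(example.getOutput())); // [0.0, 0.0, 1.0, 0.0]
      }
    }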