Author: tommaso Date: Mon Sep 28 16:49:57 2015 New Revision: 1705721 URL: http://svn.apache.org/viewvc?rev=1705721&view=rev Log: switch from batch to stochastic GD in backprop, abstracted derivative update function
Added: labs/yay/trunk/api/src/main/java/org/apache/yay/DerivativeUpdateFunction.java labs/yay/trunk/core/src/main/java/org/apache/yay/core/DefaultDerivativeUpdateFunction.java Removed: labs/yay/trunk/api/src/main/java/org/apache/yay/HypothesisFactory.java Modified: labs/yay/trunk/api/src/main/java/org/apache/yay/NeuralNetwork.java labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java labs/yay/trunk/core/src/main/java/org/apache/yay/core/BasicPerceptron.java labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java labs/yay/trunk/core/src/main/java/org/apache/yay/core/NeuralNetworkFactory.java labs/yay/trunk/core/src/main/java/org/apache/yay/core/utils/ExamplesFactory.java labs/yay/trunk/core/src/test/java/org/apache/yay/core/NeuralNetworkFactoryTest.java Added: labs/yay/trunk/api/src/main/java/org/apache/yay/DerivativeUpdateFunction.java URL: http://svn.apache.org/viewvc/labs/yay/trunk/api/src/main/java/org/apache/yay/DerivativeUpdateFunction.java?rev=1705721&view=auto ============================================================================== --- labs/yay/trunk/api/src/main/java/org/apache/yay/DerivativeUpdateFunction.java (added) +++ labs/yay/trunk/api/src/main/java/org/apache/yay/DerivativeUpdateFunction.java Mon Sep 28 16:49:57 2015 @@ -0,0 +1,30 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.yay; + +import org.apache.commons.math3.linear.RealMatrix; +import org.apache.yay.TrainingSet; + +/** + * Derivatives update function + */ +public interface DerivativeUpdateFunction<F,O> { + + RealMatrix[] updateParameters(RealMatrix[] weights, TrainingSet<F,O> trainingSet); +} Modified: labs/yay/trunk/api/src/main/java/org/apache/yay/NeuralNetwork.java URL: http://svn.apache.org/viewvc/labs/yay/trunk/api/src/main/java/org/apache/yay/NeuralNetwork.java?rev=1705721&r1=1705720&r2=1705721&view=diff ============================================================================== --- labs/yay/trunk/api/src/main/java/org/apache/yay/NeuralNetwork.java (original) +++ labs/yay/trunk/api/src/main/java/org/apache/yay/NeuralNetwork.java Mon Sep 28 16:49:57 2015 @@ -21,8 +21,19 @@ package org.apache.yay; import org.apache.commons.math3.linear.RealMatrix; /** - * A neural network is a layered connected graph of elaboration units + * A Neural Network is a layered connected graph of elaboration units. + * + * It takes a double vector as input and produces a double vector as output. */ public interface NeuralNetwork extends Hypothesis<RealMatrix, Double, Double> { + /** + * Predict the output for a given input + * + * @param input the input to evaluate + * @return the predicted output + * @throws PredictionException if any error occurs during the prediction phase + */ + Double[] getOutputVector(Input<Double> input) throws PredictionException; + } Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java?rev=1705721&r1=1705720&r2=1705721&view=diff ============================================================================== --- labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java (original) +++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java Mon Sep 28 16:49:57 2015 @@ -19,6 +19,7 @@ package org.apache.yay.core; import java.util.Arrays; +import java.util.DoubleSummaryStatistics; import java.util.Iterator; import org.apache.commons.math3.linear.Array2DRowRealMatrix; @@ -26,6 +27,7 @@ import org.apache.commons.math3.linear.A import org.apache.commons.math3.linear.RealMatrix; import org.apache.commons.math3.linear.RealVector; import org.apache.yay.CostFunction; +import org.apache.yay.DerivativeUpdateFunction; import org.apache.yay.LearningStrategy; import org.apache.yay.NeuralNetwork; import org.apache.yay.PredictionStrategy; @@ -46,6 +48,7 @@ public class BackPropagationLearningStra private final PredictionStrategy<Double, Double> predictionStrategy; private final CostFunction<RealMatrix, Double, Double> costFunction; + private final DerivativeUpdateFunction<Double, Double> derivativeUpdateFunction; private final double alpha; private final double threshold; private final int batch; @@ -63,6 +66,7 @@ public class BackPropagationLearningStra this.alpha = alpha; this.threshold = threshold; this.batch = batch; + this.derivativeUpdateFunction = new DefaultDerivativeUpdateFunction(predictionStrategy); } public BackPropagationLearningStrategy() { @@ -72,6 +76,7 @@ public class BackPropagationLearningStra this.alpha = DEFAULT_ALPHA; this.threshold = DEFAULT_THRESHOLD; this.batch = 1; + this.derivativeUpdateFunction = new DefaultDerivativeUpdateFunction(predictionStrategy); } @Override @@ -114,7 +119,7 @@ public class BackPropagationLearningStra cost = newCost; // calculate the derivatives to update the parameters - RealMatrix[] derivatives = calculateDerivatives(weightsMatrixSet, samples); + RealMatrix[] derivatives = derivativeUpdateFunction.updateParameters(weightsMatrixSet, samples); // calculate the updated parameters updatedWeights = updateWeights(weightsMatrixSet, derivatives, alpha); @@ -131,48 +136,6 @@ public class BackPropagationLearningStra return updatedWeights; } - private RealMatrix[] calculateDerivatives(RealMatrix[] weightsMatrixSet, TrainingSet<Double, Double> trainingExamples) throws WeightLearningException { - // set up the accumulator matrix(es) - RealMatrix[] triangle = new RealMatrix[weightsMatrixSet.length]; - RealVector[] deltaVectors = new RealVector[weightsMatrixSet.length]; - - int noOfMatrixes = weightsMatrixSet.length - 1; - double count = 0; - for (TrainingExample<Double, Double> trainingExample : trainingExamples) { - try { - // get activations from feed forward propagation - RealVector[] activations = predictionStrategy.debugOutput(ConversionUtils.toValuesCollection(trainingExample.getFeatures()), weightsMatrixSet); - - // calculate output error (corresponding to the last delta^l) - RealVector nextLayerDelta = calculateOutputError(trainingExample, activations); - - deltaVectors[noOfMatrixes] = nextLayerDelta; - - // back prop the error and update the deltas accordingly - for (int l = noOfMatrixes; l > 0; l--) { - RealVector currentActivationsVector = activations[l - 1]; - nextLayerDelta = calculateDeltaVector(weightsMatrixSet[l], currentActivationsVector, nextLayerDelta); - - // collect delta vectors for this example - deltaVectors[l - 1] = nextLayerDelta; - } - - RealVector[] newActivations = new RealVector[activations.length]; - newActivations[0] = ConversionUtils.toRealVector(ConversionUtils.toValuesCollection(trainingExample.getFeatures())); - System.arraycopy(activations, 0, newActivations, 1, activations.length - 1); - - // update triangle (big delta matrix) - updateTriangle(triangle, newActivations, deltaVectors, weightsMatrixSet); - - } catch (Exception e) { - throw new WeightLearningException("error during derivatives calculation", e); - } - count++; - } - - return createDerivatives(triangle, count); - } - private RealMatrix[] updateWeights(RealMatrix[] weightsMatrixSet, RealMatrix[] derivatives, double alpha) { RealMatrix[] updatedParameters = new RealMatrix[weightsMatrixSet.length]; for (int l = 0; l < weightsMatrixSet.length; l++) { @@ -187,48 +150,4 @@ public class BackPropagationLearningStra return updatedParameters; } - private RealMatrix[] createDerivatives(RealMatrix[] triangle, double count) { - RealMatrix[] derivatives = new RealMatrix[triangle.length]; - for (int i = 0; i < triangle.length; i++) { - // TODO : introduce regularization diversification on bias term (currently not regularized) - derivatives[i] = triangle[i].scalarMultiply(1d / count); - } - return derivatives; - } - - private void updateTriangle(RealMatrix[] triangle, RealVector[] activations, RealVector[] deltaVectors, RealMatrix[] weightsMatrixSet) { - for (int l = weightsMatrixSet.length - 1; l >= 0; l--) { - RealMatrix realMatrix = deltaVectors[l].outerProduct(activations[l]); - if (triangle[l] == null) { - triangle[l] = realMatrix; - } else { - triangle[l] = triangle[l].add(realMatrix); - } - } - } - - private RealVector calculateDeltaVector(RealMatrix thetaL, RealVector activationsVector, RealVector nextLayerDelta) { - // TODO : remove the bias term from the error calculations - ArrayRealVector identity = new ArrayRealVector(activationsVector.getDimension(), 1d); - RealVector gz = activationsVector.ebeMultiply(identity.subtract(activationsVector)); // = a^l .* (1-a^l) - return thetaL.preMultiply(nextLayerDelta).ebeMultiply(gz); - } - - private RealVector calculateOutputError(TrainingExample<Double, Double> trainingExample, RealVector[] activations) { - RealVector output = activations[activations.length - 1]; - - Double[] sampleOutput = new Double[output.getDimension()]; - int sampleOutputIntValue = trainingExample.getOutput().intValue(); - if (sampleOutputIntValue < sampleOutput.length) { - sampleOutput[sampleOutputIntValue] = 1d; - } else if (sampleOutput.length == 1) { - sampleOutput[0] = trainingExample.getOutput(); - } else { - throw new RuntimeException("problem with multiclass output mapping"); - } - RealVector learnedOutputRealVector = new ArrayRealVector(sampleOutput); // turn example output to a vector - - // TODO : improve error calculation -> this could be er_a = out_a * (1 - out_a) * (tgt_a - out_a) - return output.subtract(learnedOutputRealVector); - } } Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/BasicPerceptron.java URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/BasicPerceptron.java?rev=1705721&r1=1705720&r2=1705721&view=diff ============================================================================== --- labs/yay/trunk/core/src/main/java/org/apache/yay/core/BasicPerceptron.java (original) +++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/BasicPerceptron.java Mon Sep 28 16:49:57 2015 @@ -94,4 +94,13 @@ public class BasicPerceptron implements return perceptronNeuron.elaborate(ConversionUtils.toValuesCollection(input.getFeatures()).toArray( new Double[input.getFeatures().size()])); } + + @Override + public Double[] getOutputVector(Input<Double> input) throws PredictionException { + Double elaborate = perceptronNeuron.elaborate(ConversionUtils.toValuesCollection(input.getFeatures()).toArray( + new Double[input.getFeatures().size()])); + Double[] ar = new Double[1]; + ar[0] = elaborate; + return ar; + } } Added: labs/yay/trunk/core/src/main/java/org/apache/yay/core/DefaultDerivativeUpdateFunction.java URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/DefaultDerivativeUpdateFunction.java?rev=1705721&view=auto ============================================================================== --- labs/yay/trunk/core/src/main/java/org/apache/yay/core/DefaultDerivativeUpdateFunction.java (added) +++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/DefaultDerivativeUpdateFunction.java Mon Sep 28 16:49:57 2015 @@ -0,0 +1,128 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.yay.core; + +import org.apache.commons.math3.linear.ArrayRealVector; +import org.apache.commons.math3.linear.RealMatrix; +import org.apache.commons.math3.linear.RealVector; +import org.apache.yay.DerivativeUpdateFunction; +import org.apache.yay.PredictionStrategy; +import org.apache.yay.TrainingExample; +import org.apache.yay.TrainingSet; +import org.apache.yay.core.utils.ConversionUtils; + +/** + * Default derivatives update function + */ +public class DefaultDerivativeUpdateFunction implements DerivativeUpdateFunction<Double, Double> { + + private final PredictionStrategy<Double, Double> predictionStrategy; + + public DefaultDerivativeUpdateFunction(PredictionStrategy<Double, Double> predictionStrategy) { + this.predictionStrategy = predictionStrategy; + } + + @Override + public RealMatrix[] updateParameters(RealMatrix[] weightsMatrixSet, TrainingSet<Double, Double> trainingExamples) { + // set up the accumulator matrix(es) + RealMatrix[] triangle = new RealMatrix[weightsMatrixSet.length]; + RealVector[] deltaVectors = new RealVector[weightsMatrixSet.length]; + + int noOfMatrixes = weightsMatrixSet.length - 1; + double count = 0; + for (TrainingExample<Double, Double> trainingExample : trainingExamples) { + try { + // get activations from feed forward propagation + RealVector[] activations = predictionStrategy.debugOutput(ConversionUtils.toValuesCollection(trainingExample.getFeatures()), weightsMatrixSet); + + // calculate output error (corresponding to the last delta^l) + RealVector nextLayerDelta = calculateOutputError(trainingExample, activations); + + deltaVectors[noOfMatrixes] = nextLayerDelta; + + // back prop the error and update the deltas accordingly + for (int l = noOfMatrixes; l > 0; l--) { + RealVector currentActivationsVector = activations[l - 1]; + nextLayerDelta = calculateDeltaVector(weightsMatrixSet[l], currentActivationsVector, nextLayerDelta); + + // collect delta vectors for this example + deltaVectors[l - 1] = nextLayerDelta; + } + + RealVector[] newActivations = new RealVector[activations.length]; + newActivations[0] = ConversionUtils.toRealVector(ConversionUtils.toValuesCollection(trainingExample.getFeatures())); + System.arraycopy(activations, 0, newActivations, 1, activations.length - 1); + + // update triangle (big delta matrix) + updateTriangle(triangle, newActivations, deltaVectors, weightsMatrixSet); + + } catch (Exception e) { + throw new RuntimeException("error during derivatives calculation", e); + } + count++; + } + + return createDerivatives(triangle, count); + } + + private RealMatrix[] createDerivatives(RealMatrix[] triangle, double count) { + RealMatrix[] derivatives = new RealMatrix[triangle.length]; + for (int i = 0; i < triangle.length; i++) { + // TODO : introduce regularization diversification on bias term (currently not regularized) + derivatives[i] = triangle[i].scalarMultiply(1d / count); + } + return derivatives; + } + + private void updateTriangle(RealMatrix[] triangle, RealVector[] activations, RealVector[] deltaVectors, RealMatrix[] weightsMatrixSet) { + for (int l = weightsMatrixSet.length - 1; l >= 0; l--) { + RealMatrix realMatrix = deltaVectors[l].outerProduct(activations[l]); + if (triangle[l] == null) { + triangle[l] = realMatrix; + } else { + triangle[l] = triangle[l].add(realMatrix); + } + } + } + + private RealVector calculateDeltaVector(RealMatrix thetaL, RealVector activationsVector, RealVector nextLayerDelta) { + // TODO : remove the bias term from the error calculations + ArrayRealVector identity = new ArrayRealVector(activationsVector.getDimension(), 1d); + RealVector gz = activationsVector.ebeMultiply(identity.subtract(activationsVector)); // = a^l .* (1-a^l) + return thetaL.preMultiply(nextLayerDelta).ebeMultiply(gz); + } + + private RealVector calculateOutputError(TrainingExample<Double, Double> trainingExample, RealVector[] activations) { + RealVector output = activations[activations.length - 1]; + + Double[] sampleOutput = new Double[output.getDimension()]; + int sampleOutputIntValue = trainingExample.getOutput().intValue(); + if (sampleOutputIntValue < sampleOutput.length) { + sampleOutput[sampleOutputIntValue] = 1d; + } else if (sampleOutput.length == 1) { + sampleOutput[0] = trainingExample.getOutput(); + } else { + throw new RuntimeException("problem with multiclass output mapping"); + } + RealVector learnedOutputRealVector = new ArrayRealVector(sampleOutput); // turn example output to a vector + + // TODO : improve error calculation -> this could be er_a = out_a * (1 - out_a) * (tgt_a - out_a) + return output.subtract(learnedOutputRealVector); + } +} Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java?rev=1705721&r1=1705720&r2=1705721&view=diff ============================================================================== --- labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java (original) +++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java Mon Sep 28 16:49:57 2015 @@ -16,7 +16,6 @@ * specific language governing permissions and limitations * under the License. */ - package org.apache.yay.core; import java.util.ArrayList; Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/NeuralNetworkFactory.java URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/NeuralNetworkFactory.java?rev=1705721&r1=1705720&r2=1705721&view=diff ============================================================================== --- labs/yay/trunk/core/src/main/java/org/apache/yay/core/NeuralNetworkFactory.java (original) +++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/NeuralNetworkFactory.java Mon Sep 28 16:49:57 2015 @@ -18,9 +18,11 @@ */ package org.apache.yay.core; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import org.apache.commons.math3.linear.RealMatrix; +import org.apache.commons.math3.linear.RealVector; import org.apache.yay.CreationException; import org.apache.yay.Input; import org.apache.yay.LearningException; @@ -53,6 +55,12 @@ public class NeuralNetworkFactory { final SelectionFunction<Collection<Double>, Double> selectionFunction) throws CreationException { return new NeuralNetwork() { + @Override + public Double[] getOutputVector(Input<Double> input) throws PredictionException { + Collection<Double> inputVector = ConversionUtils.toValuesCollection(input.getFeatures()); + return predictionStrategy.predictOutput(inputVector, updatedRealMatrixSet); + } + private RealMatrix[] updatedRealMatrixSet = realMatrixSet; @Override @@ -77,8 +85,7 @@ public class NeuralNetworkFactory { @Override public Double predict(Input<Double> input) throws PredictionException { try { - Collection<Double> inputVector = ConversionUtils.toValuesCollection(input.getFeatures()); - Double[] doubles = predictionStrategy.predictOutput(inputVector, updatedRealMatrixSet); + Double[] doubles = getOutputVector(input); return selectionFunction.selectOutput(Arrays.asList(doubles)); } catch (Exception e) { throw new PredictionException(e); Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/utils/ExamplesFactory.java URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/utils/ExamplesFactory.java?rev=1705721&r1=1705720&r2=1705721&view=diff ============================================================================== --- labs/yay/trunk/core/src/main/java/org/apache/yay/core/utils/ExamplesFactory.java (original) +++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/utils/ExamplesFactory.java Mon Sep 28 16:49:57 2015 @@ -19,6 +19,8 @@ package org.apache.yay.core.utils; import java.util.ArrayList; +import java.util.Collection; + import org.apache.yay.Feature; import org.apache.yay.Input; import org.apache.yay.TrainingExample; @@ -41,6 +43,21 @@ public class ExamplesFactory { return output; } }; + } + + public static TrainingExample<Double, Collection<Double[]>> createSGMExample(final Collection<Double[]> output, + final Double... featuresValues) { + return new TrainingExample<Double, Collection<Double[]>>() { + @Override + public ArrayList<Feature<Double>> getFeatures() { + return doublesToFeatureVector(featuresValues); + } + + @Override + public Collection<Double[]> getOutput() { + return output; + } + }; } public static Input<Double> createDoubleInput(final Double... featuresValues) { Modified: labs/yay/trunk/core/src/test/java/org/apache/yay/core/NeuralNetworkFactoryTest.java URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/core/NeuralNetworkFactoryTest.java?rev=1705721&r1=1705720&r2=1705721&view=diff ============================================================================== --- labs/yay/trunk/core/src/test/java/org/apache/yay/core/NeuralNetworkFactoryTest.java (original) +++ labs/yay/trunk/core/src/test/java/org/apache/yay/core/NeuralNetworkFactoryTest.java Mon Sep 28 16:49:57 2015 @@ -84,8 +84,11 @@ public class NeuralNetworkFactoryTest { public void sampleCreationTest() throws Exception { RealMatrix firstLayer = new Array2DRowRealMatrix(new double[][]{{1d, 1d, 2d, 3d}, {1d, 1d, 2d, 3d}, {1d, 1d, 2d, 3d}}); RealMatrix secondLayer = new Array2DRowRealMatrix(new double[][]{{1d, 2d, 3d}}); + RealMatrix[] RealMatrixes = new RealMatrix[]{firstLayer, secondLayer}; + NeuralNetwork neuralNetwork = createFFNN(RealMatrixes); + Double prdictedValue = neuralNetwork.predict(createSample(5d, 6d, 7d)); assertEquals(1l, Math.round(prdictedValue)); assertEquals(Double.valueOf(0.9975273768433653d), prdictedValue); --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@labs.apache.org For additional commands, e-mail: commits-h...@labs.apache.org