Author: tommaso
Date: Wed Jun 19 10:24:50 2013
New Revision: 1494537
URL: http://svn.apache.org/r1494537
Log:
first fully working sketch of backpropagation
Modified:
labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java
labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java
labs/yay/trunk/core/src/main/java/org/apache/yay/core/LogisticRegressionCostFunction.java
labs/yay/trunk/core/src/test/java/org/apache/yay/core/BackPropagationLearningStrategyTest.java
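In outline, the new learning loop repeats four steps until the cost stops improving: feed each example forward to collect per-layer activations, back-propagate the output error into per-layer deltas, accumulate delta * activation^T into the 'triangle' matrices, and average those into derivatives for a gradient step. The weight update is the standard batch gradient descent rule, with Theta(l) the layer-l weight matrix, m the number of examples and alpha the learning rate:

    Theta(l) := Theta(l) - alpha * D(l),   where D(l) = (1/m) * sum over examples of delta(l+1) * a(l)^T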
Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java?rev=1494537&r1=1494536&r2=1494537&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java Wed Jun 19 10:24:50 2013
@@ -18,12 +18,15 @@
*/
package org.apache.yay.core;
+import java.util.Arrays;
+
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.ArrayRealVector;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;
import org.apache.yay.CostFunction;
import org.apache.yay.LearningStrategy;
+import org.apache.yay.NeuralNetwork;
import org.apache.yay.PredictionStrategy;
import org.apache.yay.TrainingExample;
import org.apache.yay.TrainingSet;
@@ -36,60 +39,149 @@ import org.apache.yay.core.utils.Convers
*/
public class BackPropagationLearningStrategy implements LearningStrategy<Double, Double> {
+ private static final double DEFAULT_THRESHOLD = 0.005;
+ private static final int MAX_ITERATIONS = 100000;
+ private static final double DEFAULT_ALPHA = 0.03;
+
private final PredictionStrategy<Double, Double> predictionStrategy;
- private CostFunction<RealMatrix, Double, Double> costFunction;
+ private final CostFunction<RealMatrix, Double, Double> costFunction;
+ private final double alpha;
+ private final double threshold;
+
- public BackPropagationLearningStrategy(PredictionStrategy<Double, Double> predictionStrategy, CostFunction<RealMatrix, Double, Double> costFunction) {
+ public BackPropagationLearningStrategy(double alpha, double threshold, PredictionStrategy<Double, Double> predictionStrategy, CostFunction<RealMatrix, Double, Double> costFunction) {
this.predictionStrategy = predictionStrategy;
this.costFunction = costFunction;
+ this.alpha = alpha;
+ this.threshold = threshold;
+ }
+
+ public BackPropagationLearningStrategy() {
+ // commonly used defaults
+ this.predictionStrategy = new FeedForwardStrategy(new SigmoidFunction());
+ this.costFunction = new LogisticRegressionCostFunction();
+ this.alpha = DEFAULT_ALPHA;
+ this.threshold = DEFAULT_THRESHOLD;
}
@Override
public RealMatrix[] learnWeights(RealMatrix[] weightsMatrixSet, TrainingSet<Double, Double> trainingExamples) throws WeightLearningException {
- // set up the accumulator matrix(es)
- RealMatrix[] triangle = new RealMatrix[weightsMatrixSet.length];
+ RealMatrix[] updatedWeights = weightsMatrixSet;
+ try {
+ int iterations = 0;
+
+ NeuralNetwork hypothesis = NeuralNetworkFactory.create(weightsMatrixSet, new VoidLearningStrategy<Double, Double>(), new FeedForwardStrategy(new SigmoidFunction()), new MaxSelectionFunction<Double>());
+
+ double cost = Double.MAX_VALUE;
+ while (true) {
+ // calculate cost
+ double newCost = costFunction.calculateAggregatedCost(trainingExamples, hypothesis);
+
+ if (newCost > cost) {
+ throw new RuntimeException("failed to converge at iteration " +
iterations + " with alpha "+ alpha +" : cost going from " + cost + " to " +
newCost);
+ } else if (cost == newCost || newCost < threshold || iterations >
MAX_ITERATIONS) {
+ System.out.println("successfully converged after " + iterations + "
iterations with cost " + newCost + " and parameters " +
Arrays.toString(hypothesis.getParameters()));
+ break;
+ }
- int count = 0;
- for (TrainingExample<Double, Double> trainingExample : trainingExamples) {
- try {
- // contains activation errors for the current training example
- int noOfMatrixes = weightsMatrixSet.length - 1;
+ // update registered cost
+ cost = newCost;
- // feed forward propagation
- RealVector[] activations = predictionStrategy.debugOutput(ConversionUtils.toValuesCollection(trainingExample.getFeatures()), weightsMatrixSet);
+ // calculate the derivatives to update the parameters
+ RealMatrix[] derivatives = calculateDerivatives(weightsMatrixSet, trainingExamples);
- // calculate output error
- RealVector error = calculateOutputError(trainingExample, activations);
+ // calculate the updated parameters
+ updatedWeights = updateWeights(weightsMatrixSet, derivatives, alpha);
- RealVector nextLayerDelta = error;
+ // update parameters in the hypothesis
+ hypothesis.setParameters(updatedWeights);
- triangle[noOfMatrixes] = new Array2DRowRealMatrix(weightsMatrixSet[noOfMatrixes].getColumnDimension(), weightsMatrixSet[noOfMatrixes].getRowDimension());
- triangle[noOfMatrixes] = triangle[noOfMatrixes].add(activations[noOfMatrixes - 1].outerProduct(error));
+ iterations++;
+ }
+ }
+ catch (Exception e) {
+ throw new WeightLearningException("error during backprop learning", e);
+ }
- // back prop the error and update the deltas accordingly
- for (int l = weightsMatrixSet.length - 1; l > 0; l--) {
- RealVector resultingDeltaVector = calculateDeltaVector(weightsMatrixSet[l], activations[l-1], nextLayerDelta);
+ return updatedWeights;
+ }
- if (triangle[l] == null) {
- triangle[l] = new Array2DRowRealMatrix(weightsMatrixSet[l].getRowDimension(), weightsMatrixSet[l].getColumnDimension());
+ private RealMatrix[] calculateDerivatives(RealMatrix[] weightsMatrixSet, TrainingSet<Double, Double> trainingExamples) throws WeightLearningException {
+ // set up the accumulator matrices
+ RealMatrix[] triangle = new RealMatrix[weightsMatrixSet.length];
+ RealVector[] deltaVectors = new RealVector[weightsMatrixSet.length];
+
+ int noOfMatrixes = weightsMatrixSet.length - 1;
+ double count = 0;
+ for (TrainingExample<Double, Double> trainingExample : trainingExamples) {
+ try {
+ // get activations from feed forward propagation
+ RealVector[] activations = predictionStrategy.debugOutput(ConversionUtils.toValuesCollection(trainingExample.getFeatures()), weightsMatrixSet);
+
+ // calculate output error (corresponding to the last delta^l)
+ RealVector nextLayerDelta = calculateOutputError(trainingExample, activations);
+
+ deltaVectors[noOfMatrixes] = nextLayerDelta;
+
+ // back prop the error and update the deltas accordingly
+ for (int l = noOfMatrixes; l > 0; l--) {
+ RealVector currentActivationsVector = activations[l - 1];
+ nextLayerDelta = calculateDeltaVector(weightsMatrixSet[l], currentActivationsVector, nextLayerDelta);
+
+ // collect delta vectors for this example
+ deltaVectors[l - 1] = nextLayerDelta;
+ }
+
+ RealVector[] newActivations = new RealVector[activations.length];
+ newActivations[0] = ConversionUtils.toRealVector(ConversionUtils.toValuesCollection(trainingExample.getFeatures()));
+ for (int k = 0; k < activations.length - 1; k++) {
+ newActivations[k+1] = activations[k];
+ }
+
+ // update triangle (big delta matrix)
+ updateTriangle(triangle, newActivations, deltaVectors, weightsMatrixSet);
+
+ } catch (Exception e) {
+ throw new WeightLearningException("error during derivatives
calculation", e);
}
- triangle[l] = triangle[l].add(resultingDeltaVector.outerProduct(activations[l]));
- nextLayerDelta = resultingDeltaVector;
+ count++;
}
- } catch (Exception e) {
- throw new WeightLearningException("error during phase 1 of
back-propagation algorithm", e);
- }
- count++;
+ return createDerivatives(triangle, count);
}
- for (int i = 0; i < triangle.length; i++) {
- // TODO : introduce regularization diversification on bias term (currently not regularized)
- triangle[i] = triangle[i].scalarMultiply(1 / count);
+
+ private RealMatrix[] updateWeights(RealMatrix[] weightsMatrixSet, RealMatrix[] derivatives, double alpha) {
+ RealMatrix[] updatedParameters = new RealMatrix[weightsMatrixSet.length];
+ for (int l = 0; l < weightsMatrixSet.length; l++) {
+ double[][] updatedWeights = weightsMatrixSet[l].getData();
+ for (int i = 0; i < updatedWeights.length; i++) {
+ for (int j = 0; j < updatedWeights[i].length; j++) {
+ updatedWeights[i][j] = updatedWeights[i][j] - alpha * derivatives[l].getData()[i][j];
+ }
+ }
+ updatedParameters[l] = new Array2DRowRealMatrix(updatedWeights);
+ }
+ return updatedParameters;
}
- // TODO : now apply gradient descent (or other optimization/minimization algorithms) with 'triangle' derivative terms and the cost function
+ private RealMatrix[] createDerivatives(RealMatrix[] triangle, double count) {
+ RealMatrix[] derivatives = new RealMatrix[triangle.length];
+ for (int i = 0; i < triangle.length; i++) {
+ // TODO : introduce regularization diversification on bias term (currently not regularized)
+ derivatives[i] = triangle[i].scalarMultiply(1d / count);
+ }
+ return derivatives;
+ }
- return null;
+ private void updateTriangle(RealMatrix[] triangle, RealVector[] activations, RealVector[] deltaVectors, RealMatrix[] weightsMatrixSet) {
+ for (int l = weightsMatrixSet.length - 1; l >= 0; l--) {
+ RealMatrix realMatrix = deltaVectors[l].outerProduct(activations[l]);
+ if (triangle[l] == null) {
+ triangle[l] = realMatrix;
+ } else {
+ triangle[l] = triangle[l].add(realMatrix);
+ }
+ }
}
private RealVector calculateDeltaVector(RealMatrix thetaL, RealVector activationsVector, RealVector nextLayerDelta) {
@@ -115,7 +207,7 @@ public class BackPropagationLearningStra
}
RealVector learnedOutputRealVector = new ArrayRealVector(sampleOutput); // turn example output to a vector
- // TODO : improve error calculation > this could be er_a = out_a * (1 - out_a) * (tgt_a - out_a)
+ // TODO : improve error calculation -> this could be er_a = out_a * (1 - out_a) * (tgt_a - out_a)
return output.subtract(learnedOutputRealVector);
}
}
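The heart of the new learnWeights is plain batch gradient descent with a divergence guard: abort if the cost ever rises, stop once it plateaus, falls below the threshold, or the iteration cap is reached. The same stopping logic in isolation, as a self-contained sketch on a toy one-dimensional cost (the class name and quadratic cost are illustrative only, not part of Yay):

    // Toy gradient descent mirroring learnWeights' stopping rules:
    // rising cost -> exception; plateau, threshold or iteration cap -> stop.
    public class DescentSketch {
      public static double minimize(double start, double alpha, double threshold, int maxIterations) {
        double x = start;
        double cost = Double.MAX_VALUE;
        int iterations = 0;
        while (true) {
          double newCost = (x - 3d) * (x - 3d); // J(x) = (x - 3)^2, minimum at x = 3
          if (newCost > cost) {
            throw new RuntimeException("failed to converge with alpha " + alpha);
          } else if (cost == newCost || newCost < threshold || iterations > maxIterations) {
            return x;
          }
          cost = newCost;
          x = x - alpha * 2d * (x - 3d); // gradient step, dJ/dx = 2 * (x - 3)
          iterations++;
        }
      }
    }

For instance, DescentSketch.minimize(0d, 0.1d, 1e-6d, 100000) walks x from 0 toward 3, shrinking the distance to the minimum by a constant factor per step.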
Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java?rev=1494537&r1=1494536&r2=1494537&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java Wed Jun 19 10:24:50 2013
@@ -19,6 +19,9 @@
package org.apache.yay.core;
+import java.util.ArrayList;
+import java.util.Collection;
+
import org.apache.commons.collections.CollectionUtils;
import org.apache.commons.collections.Transformer;
import org.apache.commons.math3.linear.ArrayRealVector;
@@ -28,9 +31,6 @@ import org.apache.yay.ActivationFunction
import org.apache.yay.PredictionStrategy;
import org.apache.yay.core.utils.ConversionUtils;
-import java.util.ArrayList;
-import java.util.Collection;
-
/**
* Octave code for FF to be converted :
* m = size(X, 1);
@@ -77,6 +77,7 @@ public class FeedForwardStrategy impleme
for (int j = 0; j < doubles.length; j++) {
row.add(j, doubles[j]);
}
+ // TODO : see if bias term is handled correctly here
CollectionUtils.transform(row, new ActivationRowTransformer());
double[] finRow = new double[row.size()];
for (int h = 0; h < finRow.length; h++) {
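For reference, the per-layer operation this strategy implements is a matrix-vector product followed by the element-wise sigmoid, with the bias handled by pinning one activation to 1 (which is exactly what the new TODO asks to verify). A minimal single-layer sketch using commons-math, with illustrative names:

    import org.apache.commons.math3.linear.ArrayRealVector;
    import org.apache.commons.math3.linear.RealMatrix;
    import org.apache.commons.math3.linear.RealVector;

    class FeedForwardSketch {
      // One layer of forward propagation: z = Theta * a, a' = sigmoid(z).
      static RealVector step(RealMatrix theta, RealVector activations) {
        RealVector z = theta.operate(activations); // weighted sum per unit
        double[] out = new double[z.getDimension()];
        for (int i = 0; i < out.length; i++) {
          out[i] = 1d / (1d + Math.exp(-z.getEntry(i))); // sigmoid activation
        }
        return new ArrayRealVector(out);
      }
    }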
Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/LogisticRegressionCostFunction.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/LogisticRegressionCostFunction.java?rev=1494537&r1=1494536&r2=1494537&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/LogisticRegressionCostFunction.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/LogisticRegressionCostFunction.java Wed Jun 19 10:24:50 2013
@@ -31,12 +31,17 @@ import org.apache.yay.TrainingSet;
*/
public class LogisticRegressionCostFunction implements NeuralNetworkCostFunction {
+ private static final double DEFAULT_LAMBDA = 0.1d;
private final Double lambda;
public LogisticRegressionCostFunction(Double lambda) {
this.lambda = lambda;
}
+ public LogisticRegressionCostFunction() {
+ this.lambda = DEFAULT_LAMBDA;
+ }
+
@Override
public Double calculateAggregatedCost(TrainingSet<Double, Double> trainingExamples,
                                      Hypothesis<RealMatrix, Double, Double> hypothesis) throws Exception {
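For reference, the quantity calculateAggregatedCost computes is, in the usual formulation, the regularized cross-entropy over m training examples and K output units, with lambda (now defaulting to 0.1) weighting the penalty term; bias weights are conventionally left out of the penalty sum, which is what the regularization TODOs elsewhere in this commit refer to:

    J(Theta) = -(1/m) * sum_i sum_k [ y(i)_k * log(h(x(i))_k) + (1 - y(i)_k) * log(1 - h(x(i))_k) ]
               + (lambda / (2m)) * sum of squared non-bias weights Theta(l)_ij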
Modified: labs/yay/trunk/core/src/test/java/org/apache/yay/core/BackPropagationLearningStrategyTest.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/core/BackPropagationLearningStrategyTest.java?rev=1494537&r1=1494536&r2=1494537&view=diff
==============================================================================
--- labs/yay/trunk/core/src/test/java/org/apache/yay/core/BackPropagationLearningStrategyTest.java (original)
+++ labs/yay/trunk/core/src/test/java/org/apache/yay/core/BackPropagationLearningStrategyTest.java Wed Jun 19 10:24:50 2013
@@ -29,6 +29,7 @@ import org.apache.yay.TrainingSet;
import org.apache.yay.core.utils.ExamplesFactory;
import org.junit.Test;
+import static junit.framework.Assert.assertFalse;
import static junit.framework.Assert.assertNotNull;
/**
@@ -40,18 +41,44 @@ public class BackPropagationLearningStra
public void testLearningWithRandomSamples() throws Exception {
PredictionStrategy<Double, Double> predictionStrategy = new FeedForwardStrategy(new SigmoidFunction());
BackPropagationLearningStrategy backPropagationLearningStrategy =
- new BackPropagationLearningStrategy(predictionStrategy, new LogisticRegressionCostFunction(0.4d));
+ new BackPropagationLearningStrategy(0.1d, 0.0003d, predictionStrategy, new LogisticRegressionCostFunction(0.5d));
- // 3 input units, 3 hidden units, 1 output unit
+ // 3 input units, 3 hidden units, 4 hidden units, 1 output unit
RealMatrix[] initialWeights = new RealMatrix[3];
initialWeights[0] = new Array2DRowRealMatrix(new double[][]{{0d, 0d, 0d}, {1d, 0.6d, 3d}, {1d, 2d, 2d}, {1d, 0.6d, 3d}});
- initialWeights[1] = new Array2DRowRealMatrix(new double[][]{{0d, 0d, 0d, 0d}, {0d, 0.5d, 1d, 0.5d}, {0d, 0.1d, 8d, 0.1d}, {0d, 0.1d, 8d, 0.2d}});
- initialWeights[2] = new Array2DRowRealMatrix(new double[][]{{0.7d, 2d, 0.3d, 0.5d}});
+ initialWeights[1] = new Array2DRowRealMatrix(new double[][]{{0d, 0d, 0d, 0d}, {1d, 0.5d, 1d, 0.5d}, {1d, 0.1d, 8d, 0.1d}, {1d, 0.1d, 8d, 0.2d}});
+ initialWeights[2] = new Array2DRowRealMatrix(new double[][]{{1d, 2d, 0.3d, 0.5d}});
Collection<TrainingExample<Double, Double>> samples = createSamples(100, 2);
TrainingSet<Double, Double> trainingSet = new TrainingSet<Double, Double>(samples);
- RealMatrix[] weights = backPropagationLearningStrategy.learnWeights(initialWeights, trainingSet);
- assertNotNull(weights);
+ RealMatrix[] learntWeights = backPropagationLearningStrategy.learnWeights(initialWeights, trainingSet);
+ assertNotNull(learntWeights);
+
+ assertFalse(learntWeights[0].equals(initialWeights[0]));
+ assertFalse(learntWeights[1].equals(initialWeights[1]));
+ assertFalse(learntWeights[2].equals(initialWeights[2]));
+ }
+
+ @Test
+ public void testLearningWithRandomSamplesAndRandomWeightsAndParams() throws Exception {
+ PredictionStrategy<Double, Double> predictionStrategy = new FeedForwardStrategy(new SigmoidFunction());
+ BackPropagationLearningStrategy backPropagationLearningStrategy =
+ new BackPropagationLearningStrategy(0.1d, 0.001d, predictionStrategy, new LogisticRegressionCostFunction(Math.random()));
+
+ // 3 input units, 3 hidden units, 4 hidden units, 1 output unit
+ RealMatrix[] initialWeights = new RealMatrix[3];
+ initialWeights[0] = new Array2DRowRealMatrix(new double[][]{{0d, 0d, 0d}, {1d, Math.random(), Math.random()}, {1d, Math.random(), Math.random()}, {1d, Math.random(), Math.random()}});
+ initialWeights[1] = new Array2DRowRealMatrix(new double[][]{{0d, 0d, 0d, 0d}, {1d, Math.random(), Math.random(), Math.random()}, {1d, Math.random(), Math.random(), Math.random()}, {1d, Math.random(), Math.random(), Math.random()}});
+ initialWeights[2] = new Array2DRowRealMatrix(new double[][]{{1d, Math.random(), Math.random(), Math.random()}});
+
+ Collection<TrainingExample<Double, Double>> samples = createSamples(50, 2);
+ TrainingSet<Double, Double> trainingSet = new TrainingSet<Double, Double>(samples);
+ RealMatrix[] learntWeights = backPropagationLearningStrategy.learnWeights(initialWeights, trainingSet);
+ assertNotNull(learntWeights);
+
+ assertFalse(learntWeights[0].equals(initialWeights[0]));
+ assertFalse(learntWeights[1].equals(initialWeights[1]));
+ assertFalse(learntWeights[2].equals(initialWeights[2]));
}
private Collection<TrainingExample<Double, Double>> createSamples(int size, int noOfFeatures) {
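With the defaults introduced here, a caller can also skip choosing alpha and threshold entirely. A minimal usage sketch, assuming initialWeights and trainingSet built as in the tests above:

    // uses the commit's defaults: alpha = 0.03, threshold = 0.005,
    // sigmoid feed-forward prediction, logistic cost with lambda = 0.1
    BackPropagationLearningStrategy strategy = new BackPropagationLearningStrategy();
    RealMatrix[] learntWeights = strategy.learnWeights(initialWeights, trainingSet);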