Author: tommaso
Date: Wed Jun 19 10:24:50 2013
New Revision: 1494537

URL: http://svn.apache.org/r1494537
Log:
first fully working sketch of backpropagation: learnWeights now runs an iterative gradient descent loop (aggregated cost check, per-example deltas accumulated into the 'triangle' matrices, averaged derivatives, weight updates) instead of returning null; also adds no-arg constructors with common defaults and convergence assertions in the tests

Modified:
    labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/core/LogisticRegressionCostFunction.java
    labs/yay/trunk/core/src/test/java/org/apache/yay/core/BackPropagationLearningStrategyTest.java

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java?rev=1494537&r1=1494536&r2=1494537&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java Wed Jun 19 10:24:50 2013
@@ -18,12 +18,15 @@
  */
 package org.apache.yay.core;
 
+import java.util.Arrays;
+
 import org.apache.commons.math3.linear.Array2DRowRealMatrix;
 import org.apache.commons.math3.linear.ArrayRealVector;
 import org.apache.commons.math3.linear.RealMatrix;
 import org.apache.commons.math3.linear.RealVector;
 import org.apache.yay.CostFunction;
 import org.apache.yay.LearningStrategy;
+import org.apache.yay.NeuralNetwork;
 import org.apache.yay.PredictionStrategy;
 import org.apache.yay.TrainingExample;
 import org.apache.yay.TrainingSet;
@@ -36,60 +39,149 @@ import org.apache.yay.core.utils.Convers
  */
 public class BackPropagationLearningStrategy implements LearningStrategy<Double, Double> {
 
+  private static final double DEFAULT_THRESHOLD = 0.005;
+  private static final int MAX_ITERATIONS = 100000;
+  private static final double DEFAULT_ALPHA = 0.03;
+
   private final PredictionStrategy<Double, Double> predictionStrategy;
-  private CostFunction<RealMatrix, Double, Double> costFunction;
+  private final CostFunction<RealMatrix, Double, Double> costFunction;
+  private final double alpha;
+  private final double threshold;
 
-  public BackPropagationLearningStrategy(PredictionStrategy<Double, Double> predictionStrategy, CostFunction<RealMatrix, Double, Double> costFunction) {
+  public BackPropagationLearningStrategy(double alpha, double threshold, PredictionStrategy<Double, Double> predictionStrategy, CostFunction<RealMatrix, Double, Double> costFunction) {
     this.predictionStrategy = predictionStrategy;
     this.costFunction = costFunction;
+    this.alpha = alpha;
+    this.threshold = threshold;
+  }
+
+  public BackPropagationLearningStrategy() {
+    // commonly used defaults
+    this.predictionStrategy = new FeedForwardStrategy(new SigmoidFunction());
+    this.costFunction = new LogisticRegressionCostFunction();
+    this.alpha = DEFAULT_ALPHA;
+    this.threshold = DEFAULT_THRESHOLD;
   }
 
   @Override
   public RealMatrix[] learnWeights(RealMatrix[] weightsMatrixSet, TrainingSet<Double, Double> trainingExamples) throws WeightLearningException {
-    // set up the accumulator matrix(es)
-    RealMatrix[] triangle = new RealMatrix[weightsMatrixSet.length];
+    RealMatrix[] updatedWeights = weightsMatrixSet;
+    try {
+      int iterations = 0;
+
+      NeuralNetwork hypothesis = NeuralNetworkFactory.create(weightsMatrixSet, new VoidLearningStrategy<Double, Double>(), new FeedForwardStrategy(new SigmoidFunction()), new MaxSelectionFunction<Double>());
+
+      double cost = Double.MAX_VALUE;
+      while (true) {
+        // calculate cost
+        double newCost = costFunction.calculateAggregatedCost(trainingExamples, hypothesis);
+
+        if (newCost > cost) {
+          throw new RuntimeException("failed to converge at iteration " + iterations + " with alpha " + alpha + ": cost going from " + cost + " to " + newCost);
+        } else if (cost == newCost || newCost < threshold || iterations > MAX_ITERATIONS) {
+          System.out.println("successfully converged after " + iterations + " iterations with cost " + newCost + " and parameters " + Arrays.toString(hypothesis.getParameters()));
+          break;
+        }
 
-    int count = 0;
-    for (TrainingExample<Double, Double> trainingExample : trainingExamples) {
-      try {
-        // contains activation errors for the current training example
-        int noOfMatrixes = weightsMatrixSet.length - 1;
+        // update registered cost
+        cost = newCost;
 
-        // feed forward propagation
-        RealVector[] activations = predictionStrategy.debugOutput(ConversionUtils.toValuesCollection(trainingExample.getFeatures()), weightsMatrixSet);
+        // calculate the derivatives of the cost at the current parameters
+        RealMatrix[] derivatives = calculateDerivatives(updatedWeights, trainingExamples);

-        // calculate output error
-        RealVector error = calculateOutputError(trainingExample, activations);
+        // take one gradient descent step to obtain the updated parameters
+        updatedWeights = updateWeights(updatedWeights, derivatives, alpha);

-        RealVector nextLayerDelta = error;
+        // update parameters in the hypothesis
+        hypothesis.setParameters(updatedWeights);

-        triangle[noOfMatrixes] = new Array2DRowRealMatrix(weightsMatrixSet[noOfMatrixes].getColumnDimension(), weightsMatrixSet[noOfMatrixes].getRowDimension());
-        triangle[noOfMatrixes] = triangle[noOfMatrixes].add(activations[noOfMatrixes - 1].outerProduct(error));
+        iterations++;
+      }
+    } catch (Exception e) {
+      throw new WeightLearningException("error during backprop learning", e);
+    }
 
-        // back prop the error and update the deltas accordingly
-        for (int l = weightsMatrixSet.length - 1; l > 0; l--) {
-          RealVector resultingDeltaVector = calculateDeltaVector(weightsMatrixSet[l], activations[l-1], nextLayerDelta);
+    return updatedWeights;
+  }

-          if (triangle[l] == null) {
-            triangle[l] = new Array2DRowRealMatrix(weightsMatrixSet[l].getRowDimension(), weightsMatrixSet[l].getColumnDimension());
+    private RealMatrix[] calculateDerivatives(RealMatrix[] weightsMatrixSet, TrainingSet<Double, Double> trainingExamples) throws WeightLearningException {
+        // set up the accumulator (triangle) matrices
+        RealMatrix[] triangle = new RealMatrix[weightsMatrixSet.length];
+        RealVector[] deltaVectors = new RealVector[weightsMatrixSet.length];
+
+        int noOfMatrixes = weightsMatrixSet.length - 1;
+        double count = 0;
+        for (TrainingExample<Double, Double> trainingExample : trainingExamples) {
+          try {
+            // get activations from feed forward propagation
+            RealVector[] activations = predictionStrategy.debugOutput(ConversionUtils.toValuesCollection(trainingExample.getFeatures()), weightsMatrixSet);
+
+            // calculate output error (corresponding to the last delta^l)
+            RealVector nextLayerDelta = calculateOutputError(trainingExample, activations);
+
+            deltaVectors[noOfMatrixes] = nextLayerDelta;
+
+            // back prop the error and update the deltas accordingly
+            for (int l = noOfMatrixes; l > 0; l--) {
+              RealVector currentActivationsVector = activations[l - 1];
+              nextLayerDelta = calculateDeltaVector(weightsMatrixSet[l], currentActivationsVector, nextLayerDelta);
+
+              // collect delta vectors for this example
+              deltaVectors[l - 1] = nextLayerDelta;
+            }
+
+            RealVector[] newActivations = new RealVector[activations.length];
+            newActivations[0] = ConversionUtils.toRealVector(ConversionUtils.toValuesCollection(trainingExample.getFeatures()));
+            for (int k = 0; k < activations.length - 1; k++) {
+                newActivations[k + 1] = activations[k];
+            }
+
+            // update triangle (big delta matrix)
+            updateTriangle(triangle, newActivations, deltaVectors, weightsMatrixSet);
+
+          } catch (Exception e) {
+            throw new WeightLearningException("error during derivatives calculation", e);
           }
-          triangle[l] = triangle[l].add(resultingDeltaVector.outerProduct(activations[l]));
-          nextLayerDelta = resultingDeltaVector;
+          count++;
         }
 
-      } catch (Exception e) {
-        throw new WeightLearningException("error during phase 1 of back-propagation algorithm", e);
-      }
-      count++;
+        return createDerivatives(triangle, count);
     }
-    for (int i = 0; i < triangle.length; i++) {
-      // TODO : introduce regularization diversification on bias term (currently not regularized)
-      triangle[i] = triangle[i].scalarMultiply(1 / count);
+
+    private RealMatrix[] updateWeights(RealMatrix[] weightsMatrixSet, RealMatrix[] derivatives, double alpha) {
+        RealMatrix[] updatedParameters = new RealMatrix[weightsMatrixSet.length];
+        for (int l = 0; l < weightsMatrixSet.length; l++) {
+            double[][] updatedWeights = weightsMatrixSet[l].getData();
+            // fetch the derivative data once instead of copying the matrix via getData() on every inner iteration
+            double[][] derivativesData = derivatives[l].getData();
+            for (int i = 0; i < updatedWeights.length; i++) {
+                for (int j = 0; j < updatedWeights[i].length; j++) {
+                    updatedWeights[i][j] -= alpha * derivativesData[i][j];
+                }
+            }
+            updatedParameters[l] = new Array2DRowRealMatrix(updatedWeights);
+        }
+        return updatedParameters;
     }
 
-    // TODO : now apply gradient descent (or other optimization/minimization algorithms) with 'triangle' derivative terms and the cost function
+    private RealMatrix[] createDerivatives(RealMatrix[] triangle, double count) {
+        RealMatrix[] derivatives = new RealMatrix[triangle.length];
+        for (int i = 0; i < triangle.length; i++) {
+          // TODO : introduce regularization, treating the bias term differently (currently nothing is regularized)
+          derivatives[i] = triangle[i].scalarMultiply(1d / count);
+        }
+        return derivatives;
+    }
 
-    return null;
+    private void updateTriangle(RealMatrix[] triangle, RealVector[] activations, RealVector[] deltaVectors, RealMatrix[] weightsMatrixSet) {
+      for (int l = weightsMatrixSet.length - 1; l >= 0; l--) {
+          RealMatrix realMatrix = deltaVectors[l].outerProduct(activations[l]);
+          if (triangle[l] == null) {
+              triangle[l] = realMatrix;
+          } else {
+              triangle[l] = triangle[l].add(realMatrix);
+          }
+      }
   }
 
   private RealVector calculateDeltaVector(RealMatrix thetaL, RealVector activationsVector, RealVector nextLayerDelta) {
@@ -115,7 +207,7 @@ public class BackPropagationLearningStra
     }
     RealVector learnedOutputRealVector = new ArrayRealVector(sampleOutput); // turn example output to a vector
 
-    // TODO : improve error calculation > this could be er_a = out_a * (1 - out_a) * (tgt_a - out_a)
+    // TODO : improve error calculation -> this could be er_a = out_a * (1 - out_a) * (tgt_a - out_a)
     return output.subtract(learnedOutputRealVector);
   }
 }
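
A note on the math the new code implements: the loop in learnWeights is plain batch gradient descent driven by backpropagated errors. In the sketch below the notation is mine, not the code's: \Theta^{(l)} is weightsMatrixSet[l], \Delta^{(l)} is triangle[l], m is count and \alpha is the learning rate alpha. The \delta^{(l)} recurrence is the textbook sigmoid form and is an assumption here, since the body of calculateDeltaVector lies outside this diff:

    \delta^{(L)} = a^{(L)} - y                                                          % calculateOutputError
    \delta^{(l)} = ((\Theta^{(l)})^T \delta^{(l+1)}) \circ a^{(l)} \circ (1 - a^{(l)})  % calculateDeltaVector (assumed)
    \Delta^{(l)} \leftarrow \Delta^{(l)} + \delta^{(l+1)} (a^{(l)})^T                   % updateTriangle
    D^{(l)} = \frac{1}{m} \Delta^{(l)}                                                  % createDerivatives (no regularization yet, see TODO)
    \Theta^{(l)} \leftarrow \Theta^{(l)} - \alpha D^{(l)}                               % updateWeights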

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java?rev=1494537&r1=1494536&r2=1494537&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java Wed Jun 19 10:24:50 2013
@@ -19,6 +19,9 @@
 
 package org.apache.yay.core;
 
+import java.util.ArrayList;
+import java.util.Collection;
+
 import org.apache.commons.collections.CollectionUtils;
 import org.apache.commons.collections.Transformer;
 import org.apache.commons.math3.linear.ArrayRealVector;
@@ -28,9 +31,6 @@ import org.apache.yay.ActivationFunction
 import org.apache.yay.PredictionStrategy;
 import org.apache.yay.core.utils.ConversionUtils;
 
-import java.util.ArrayList;
-import java.util.Collection;
-
 /**
  * Octave code for FF to be converted :
  * m = size(X, 1);
@@ -77,6 +77,7 @@ public class FeedForwardStrategy impleme
         for (int j = 0; j < doubles.length; j++) {
           row.add(j, doubles[j]);
         }
+        // TODO : see if bias term is handled correctly here
         CollectionUtils.transform(row, new ActivationRowTransformer());
         double[] finRow = new double[row.size()];
         for (int h = 0; h < finRow.length; h++) {
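
On the bias TODO above: a minimal, self-contained sketch of one feed-forward layer step that handles the bias term explicitly, i.e. a = g(theta * [1; x]). This is an illustration in plain commons-math3 with a hypothetical class name, not this class's actual code path:

    import org.apache.commons.math3.linear.ArrayRealVector;
    import org.apache.commons.math3.linear.RealMatrix;
    import org.apache.commons.math3.linear.RealVector;

    class LayerStepSketch {
      // one layer step: prepend the bias unit, multiply by the weights, apply the sigmoid
      static RealVector layerStep(RealMatrix theta, RealVector input) {
        RealVector biased = new ArrayRealVector(new double[]{1d}).append(input); // [1; x]
        RealVector z = theta.operate(biased); // theta is (units) x (1 + input size)
        double[] a = new double[z.getDimension()];
        for (int i = 0; i < a.length; i++) {
          a[i] = 1d / (1d + Math.exp(-z.getEntry(i))); // sigmoid activation
        }
        return new ArrayRealVector(a);
      }
    }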

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/LogisticRegressionCostFunction.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/LogisticRegressionCostFunction.java?rev=1494537&r1=1494536&r2=1494537&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/LogisticRegressionCostFunction.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/LogisticRegressionCostFunction.java Wed Jun 19 10:24:50 2013
@@ -31,12 +31,17 @@ import org.apache.yay.TrainingSet;
  */
 public class LogisticRegressionCostFunction implements NeuralNetworkCostFunction {
 
+  private static final double DEFAULT_LAMBDA = 0.1d;
   private final Double lambda;
 
   public LogisticRegressionCostFunction(Double lambda) {
     this.lambda = lambda;
   }
 
+  public LogisticRegressionCostFunction() {
+    this.lambda = DEFAULT_LAMBDA;
+  }
+
   @Override
   public Double calculateAggregatedCost(TrainingSet<Double, Double> trainingExamples,
                               Hypothesis<RealMatrix, Double, Double> hypothesis) throws Exception {
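
For context, the new no-arg constructor pins lambda to DEFAULT_LAMBDA = 0.1. The cost this class is named after is the standard regularized logistic regression cost; the method body is outside this diff, so this is the textbook formula rather than a transcript of the implementation:

    J(\Theta) = -\frac{1}{m} \sum_{i=1}^{m} \left[ y^{(i)} \log h_\Theta(x^{(i)}) + (1 - y^{(i)}) \log(1 - h_\Theta(x^{(i)})) \right] + \frac{\lambda}{2m} \sum_{l,i,j} (\Theta^{(l)}_{ij})^2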

Modified: labs/yay/trunk/core/src/test/java/org/apache/yay/core/BackPropagationLearningStrategyTest.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/core/BackPropagationLearningStrategyTest.java?rev=1494537&r1=1494536&r2=1494537&view=diff
==============================================================================
--- labs/yay/trunk/core/src/test/java/org/apache/yay/core/BackPropagationLearningStrategyTest.java (original)
+++ labs/yay/trunk/core/src/test/java/org/apache/yay/core/BackPropagationLearningStrategyTest.java Wed Jun 19 10:24:50 2013
@@ -29,6 +29,7 @@ import org.apache.yay.TrainingSet;
 import org.apache.yay.core.utils.ExamplesFactory;
 import org.junit.Test;
 
+import static junit.framework.Assert.assertFalse;
 import static junit.framework.Assert.assertNotNull;
 
 /**
@@ -40,18 +41,44 @@ public class BackPropagationLearningStra
   public void testLearningWithRandomSamples() throws Exception {
     PredictionStrategy<Double, Double> predictionStrategy = new FeedForwardStrategy(new SigmoidFunction());
     BackPropagationLearningStrategy backPropagationLearningStrategy =
-            new BackPropagationLearningStrategy(predictionStrategy, new LogisticRegressionCostFunction(0.4d));
+            new BackPropagationLearningStrategy(0.1d, 0.0003d, predictionStrategy, new LogisticRegressionCostFunction(0.5d));
 
-    // 3 input units, 3 hidden units, 1 output unit
+    // 3 input units, 3 hidden units, 4 hidden units, 1 output unit
     RealMatrix[] initialWeights = new RealMatrix[3];
     initialWeights[0] = new Array2DRowRealMatrix(new double[][]{{0d, 0d, 0d}, {1d, 0.6d, 3d}, {1d, 2d, 2d}, {1d, 0.6d, 3d}});
-    initialWeights[1] = new Array2DRowRealMatrix(new double[][]{{0d, 0d, 0d, 0d}, {0d, 0.5d, 1d, 0.5d}, {0d, 0.1d, 8d, 0.1d}, {0d, 0.1d, 8d, 0.2d}});
-    initialWeights[2] = new Array2DRowRealMatrix(new double[][]{{0.7d, 2d, 0.3d, 0.5d}});
+    initialWeights[1] = new Array2DRowRealMatrix(new double[][]{{0d, 0d, 0d, 0d}, {1d, 0.5d, 1d, 0.5d}, {1d, 0.1d, 8d, 0.1d}, {1d, 0.1d, 8d, 0.2d}});
+    initialWeights[2] = new Array2DRowRealMatrix(new double[][]{{1d, 2d, 0.3d, 0.5d}});
 
     Collection<TrainingExample<Double, Double>> samples = createSamples(100, 2);
     TrainingSet<Double, Double> trainingSet = new TrainingSet<Double, Double>(samples);
-    RealMatrix[] weights = backPropagationLearningStrategy.learnWeights(initialWeights, trainingSet);
-    assertNotNull(weights);
+    RealMatrix[] learntWeights = backPropagationLearningStrategy.learnWeights(initialWeights, trainingSet);
+    assertNotNull(learntWeights);
+
+    assertFalse(learntWeights[0].equals(initialWeights[0]));
+    assertFalse(learntWeights[1].equals(initialWeights[1]));
+    assertFalse(learntWeights[2].equals(initialWeights[2]));
+  }
+
+  @Test
+  public void testLearningWithRandomSamplesAndRandomWeightsAndParams() throws Exception {
+    PredictionStrategy<Double, Double> predictionStrategy = new FeedForwardStrategy(new SigmoidFunction());
+    BackPropagationLearningStrategy backPropagationLearningStrategy =
+            new BackPropagationLearningStrategy(0.1d, 0.001d, predictionStrategy, new LogisticRegressionCostFunction(Math.random()));
+
+    // 3 input units, 3 hidden units, 4 hidden units, 1 output unit
+    RealMatrix[] initialWeights = new RealMatrix[3];
+    initialWeights[0] = new Array2DRowRealMatrix(new double[][]{{0d, 0d, 0d}, {1d, Math.random(), Math.random()}, {1d, Math.random(), Math.random()}, {1d, Math.random(), Math.random()}});
+    initialWeights[1] = new Array2DRowRealMatrix(new double[][]{{0d, 0d, 0d, 0d}, {1d, Math.random(), Math.random(), Math.random()}, {1d, Math.random(), Math.random(), Math.random()}, {1d, Math.random(), Math.random(), Math.random()}});
+    initialWeights[2] = new Array2DRowRealMatrix(new double[][]{{1d, Math.random(), Math.random(), Math.random()}});
+
+    Collection<TrainingExample<Double, Double>> samples = createSamples(50, 2);
+    TrainingSet<Double, Double> trainingSet = new TrainingSet<Double, Double>(samples);
+    RealMatrix[] learntWeights = backPropagationLearningStrategy.learnWeights(initialWeights, trainingSet);
+    assertNotNull(learntWeights);
+
+    assertFalse(learntWeights[0].equals(initialWeights[0]));
+    assertFalse(learntWeights[1].equals(initialWeights[1]));
+    assertFalse(learntWeights[2].equals(initialWeights[2]));
   }
 
   private Collection<TrainingExample<Double, Double>> createSamples(int size, int noOfFeatures) {
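
As a usage note, the new no-arg constructor added in this commit makes the strategy usable without choosing hyperparameters; a minimal sketch, with initialWeights and trainingSet built as in the tests above:

    // defaults: alpha = 0.03, threshold = 0.005, sigmoid feed-forward, logistic cost with lambda = 0.1
    BackPropagationLearningStrategy strategy = new BackPropagationLearningStrategy();
    RealMatrix[] learntWeights = strategy.learnWeights(initialWeights, trainingSet);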


