Author: tommaso
Date: Mon Feb 25 07:48:17 2013
New Revision: 1449612

URL: http://svn.apache.org/r1449612
Log:
started refactoring backprop: extracted the output error and per-layer delta computations out of learnWeights into calculateOutputError and calculateDeltaVector helpers

Modified:
    labs/yay/trunk/core/src/main/java/org/apache/yay/BackPropagationLearningStrategy.java

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/BackPropagationLearningStrategy.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/BackPropagationLearningStrategy.java?rev=1449612&r1=1449611&r2=1449612&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/BackPropagationLearningStrategy.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/BackPropagationLearningStrategy.java Mon Feb 25 07:48:17 2013
@@ -32,70 +32,81 @@ import java.util.Collection;
  */
 public class BackPropagationLearningStrategy implements LearningStrategy<Double, Double[]> {
 
-  private final PredictionStrategy<Double, Double[]> predictionStrategy;
-  private CostFunction<RealMatrix, Double> costFunction;
+    private final PredictionStrategy<Double, Double[]> predictionStrategy;
+    private CostFunction<RealMatrix, Double> costFunction;
 
-  public BackPropagationLearningStrategy(PredictionStrategy<Double, Double[]> predictionStrategy, CostFunction<RealMatrix, Double> costFunction) {
-    this.predictionStrategy = predictionStrategy;
-    this.costFunction = costFunction;
-  }
-
-  @Override
-  public RealMatrix[] learnWeights(RealMatrix[] weightsMatrixSet, Collection<TrainingExample<Double, Double[]>> trainingExamples) throws WeightLearningException {
-    // set up the accumulator matrix(es)
-    RealMatrix[] triangle = new RealMatrix[weightsMatrixSet.length];
-    for (TrainingExample<Double, Double[]> trainingExample : trainingExamples) {
-      try {
-        // contains activation errors for the current training example
-        // TODO : check if this should be RealVector[] < probably yes
-        RealMatrix[] activationErrors = new RealMatrix[weightsMatrixSet.length - 1];
-
-        // feed forward propagation
-        RealMatrix[] activations = predictionStrategy.debugOutput(ConversionUtils.toValuesCollection(trainingExample.getFeatures()), weightsMatrixSet);
-        RealMatrix output = activations[activations.length - 1];
-        Double[] learnedOutput = trainingExample.getOutput(); // training example output
-        RealVector predictedOutputVector = new ArrayRealVector(output.getColumn(output.getColumnDimension() - 1)); // turn output to vector
-        RealVector learnedOutputRealVector = new ArrayRealVector(learnedOutput); // turn example output to a vector
-
-        RealVector error = predictedOutputVector.subtract(learnedOutputRealVector); // final layer error vector
-        activationErrors[activationErrors.length - 1] = new Array2DRowRealMatrix(error.toArray());
-
-        RealVector nextLayerDelta = new ArrayRealVector(error);
+    public BackPropagationLearningStrategy(PredictionStrategy<Double, Double[]> predictionStrategy, CostFunction<RealMatrix, Double> costFunction) {
+        this.predictionStrategy = predictionStrategy;
+        this.costFunction = costFunction;
+    }
 
-        // back prop the error and update the activationErrors accordingly
-        // TODO : remove the bias term from the error calculations
-        for (int l = weightsMatrixSet.length - 2; l >= 0; l--) {
-          RealMatrix thetaL = weightsMatrixSet[l];
-          ArrayRealVector activationsVector = new ArrayRealVector(activations[l].getRowVector(0)); // get l-th nn layer activations
-          ArrayRealVector identity = new ArrayRealVector(activationsVector.getDimension(), 1d);
-          RealVector gz = activationsVector.ebeMultiply(identity.subtract(activationsVector)); // = a^l .* (1-a^l)
-          RealVector resultingDeltaVector = thetaL.transpose().preMultiply(nextLayerDelta).ebeMultiply(gz);
-          if (activationErrors[l] == null) {
-            activationErrors[l] = new Array2DRowRealMatrix(new ArrayRealVector(resultingDeltaVector.getDimension(), 1d).toArray());
-          }
-          activationErrors[l] = new Array2DRowRealMatrix(resultingDeltaVector.toArray());
-          nextLayerDelta = resultingDeltaVector;
+    @Override
+    public RealMatrix[] learnWeights(RealMatrix[] weightsMatrixSet, Collection<TrainingExample<Double, Double[]>> trainingExamples) throws WeightLearningException {
+        // set up the accumulator matrices
+        RealMatrix[] triangle = new RealMatrix[weightsMatrixSet.length];
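+        // one gradient accumulator per weight matrix; summed per training example and averaged below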
+        for (TrainingExample<Double, Double[]> trainingExample : trainingExamples) {
+            try {
+                // contains activation errors for the current training example
+                // TODO : check if this should be RealVector[] < probably yes
+                RealMatrix[] activationErrors = new RealMatrix[weightsMatrixSet.length - 1];
+
+                // feed forward propagation
+                RealMatrix[] activations = predictionStrategy.debugOutput(ConversionUtils.toValuesCollection(trainingExample.getFeatures()), weightsMatrixSet);
+
+                // calculate output error
+                RealVector error = calculateOutputError(trainingExample, activations);
+
+                activationErrors[activationErrors.length - 1] = new Array2DRowRealMatrix(error.toArray());
+
+                RealVector nextLayerDelta = new ArrayRealVector(error);
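+                // the output-layer error seeds the backward pass through the hidden layers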
+
+                // back prop the error and update the activationErrors accordingly
+                // TODO : eventually remove the bias term from the error calculations
+                for (int l = weightsMatrixSet.length - 2; l >= 0; l--) {
+                    RealVector resultingDeltaVector = calculateDeltaVector(weightsMatrixSet[l], activations[l], nextLayerDelta);
+                    // the delta always overwrites any previous value, so no null check is needed
+                    activationErrors[l] = new Array2DRowRealMatrix(resultingDeltaVector.toArray());
+                    nextLayerDelta = resultingDeltaVector;
+                }
+
+                // update the accumulator matrix
+                for (int l = 0; l < triangle.length - 1; l++) {
+                    if (triangle[l] == null) {
+                        triangle[l] = new Array2DRowRealMatrix(weightsMatrixSet[l].getRowDimension(), weightsMatrixSet[l].getColumnDimension());
+                    }
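+                    // accumulate this example's gradient contribution: outer product of the next layer's delta and this layer's activations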
+                    triangle[l] = triangle[l].add(activationErrors[l + 1].getRowVector(0).outerProduct(activations[l].getRowVector(0)));
+                }
+
+            } catch (Exception e) {
+                throw new WeightLearningException("error during phase 1 of back-propagation algorithm", e);
+            }
         }
-
-        // update the accumulator matrix
-        for (int l = 0; l < triangle.length - 1; l++) {
-          if (triangle[l] == null) {
-            triangle[l] = new Array2DRowRealMatrix(weightsMatrixSet[l].getRowDimension(), weightsMatrixSet[l].getColumnDimension());
-          }
-          triangle[l] = triangle[l].add(activationErrors[l+1].getRowVector(0).outerProduct(activations[l].getRowVector(0)));
+        for (int i = 0; i < triangle.length; i++) {
+            // TODO : introduce regularization diversification on bias term (currently not regularized)
+            // use a double literal so integer division does not truncate the factor to zero
+            triangle[i] = triangle[i].scalarMultiply(1d / trainingExamples.size());
         }
 
-      } catch (Exception e) {
-        throw new WeightLearningException("error during phase 1 of back-propagation algorithm", e);
-      }
+        // TODO : now apply gradient descent (or other optimization/minimization algorithms) with these derivative terms and the cost function
+
+        return null;
     }
-    for (int i = 0; i < triangle.length; i++) {
-      // TODO : introduce regularization diversification on bias term (currently not regularized)
-      triangle[i] = triangle[i].scalarMultiply(1 / trainingExamples.size());
+
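+    // back-propagates the next layer's delta through thetaL and scales it
+    // element-wise by the sigmoid gradient a^l .* (1 - a^l)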
+    private RealVector calculateDeltaVector(RealMatrix thetaL, RealMatrix activation, RealVector nextLayerDelta) {
+        ArrayRealVector activationsVector = new ArrayRealVector(activation.getRowVector(0)); // get l-th nn layer activations
+        ArrayRealVector identity = new ArrayRealVector(activationsVector.getDimension(), 1d);
+        RealVector gz = activationsVector.ebeMultiply(identity.subtract(activationsVector)); // = a^l .* (1-a^l)
+        return thetaL.transpose().preMultiply(nextLayerDelta).ebeMultiply(gz);
     }
 
-    // TODO : now apply gradient descent (or other optimization/minimization algorithms) with this derivative terms and the cost function
+    private RealVector calculateOutputError(TrainingExample<Double, Double[]> trainingExample, RealMatrix[] activations) {
+        RealMatrix output = activations[activations.length - 1];
+        Double[] learnedOutput = trainingExample.getOutput(); // training example output
+        RealVector predictedOutputVector = new ArrayRealVector(output.getColumn(output.getColumnDimension() - 1)); // turn output to vector
+        RealVector learnedOutputRealVector = new ArrayRealVector(learnedOutput); // turn example output to a vector
 
-    return null;
-  }
+        // TODO : improve error calculation > this should be er_a = out_a * (1 - out_a) * (tgt_a - out_a)
+        return predictedOutputVector.subtract(learnedOutputRealVector);
+    }
 }
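
For reference, calculateDeltaVector back-propagates the next layer's delta through the layer's weights and scales it element-wise by the sigmoid gradient a^l .* (1 - a^l). A minimal standalone sketch of the same arithmetic, with made-up values and assuming the Commons Math 3 linear package:

    import org.apache.commons.math3.linear.Array2DRowRealMatrix;
    import org.apache.commons.math3.linear.ArrayRealVector;
    import org.apache.commons.math3.linear.RealMatrix;
    import org.apache.commons.math3.linear.RealVector;

    public class DeltaVectorSketch {
        public static void main(String[] args) {
            // made-up 2x2 weights theta^l, layer activations a^l and next layer's delta
            RealMatrix thetaL = new Array2DRowRealMatrix(new double[][]{{0.1, 0.4}, {0.8, 0.6}});
            ArrayRealVector activations = new ArrayRealVector(new double[]{0.75, 0.2});
            RealVector nextLayerDelta = new ArrayRealVector(new double[]{0.05, -0.1});

            // sigmoid gradient recovered from the activations: a^l .* (1 - a^l)
            RealVector ones = new ArrayRealVector(activations.getDimension(), 1d);
            RealVector gz = activations.ebeMultiply(ones.subtract(activations));

            // same expression as calculateDeltaVector: preMultiply applies the
            // delta as a row vector on the left of theta^T
            RealVector delta = thetaL.transpose().preMultiply(nextLayerDelta).ebeMultiply(gz);
            System.out.println(delta);
        }
    }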


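Similarly, the TODO in calculateOutputError suggests scaling the raw (predicted - target) difference by the sigmoid gradient. A hypothetical helper implementing that suggested formula, er_a = out_a * (1 - out_a) * (tgt_a - out_a), could look like this sketch (not part of this commit):

    import org.apache.commons.math3.linear.ArrayRealVector;
    import org.apache.commons.math3.linear.RealVector;

    public class OutputErrorSketch {
        // er_a = out_a * (1 - out_a) * (tgt_a - out_a), as suggested by the TODO
        static RealVector sigmoidOutputError(RealVector out, RealVector tgt) {
            RealVector ones = new ArrayRealVector(out.getDimension(), 1d);
            return out.ebeMultiply(ones.subtract(out)).ebeMultiply(tgt.subtract(out));
        }

        public static void main(String[] args) {
            RealVector out = new ArrayRealVector(new double[]{0.9, 0.1});
            RealVector tgt = new ArrayRealVector(new double[]{1d, 0d});
            System.out.println(sigmoidOutputError(out, tgt)); // small errors near a correct prediction
        }
    }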
