Author: tommaso
Date: Wed Dec 12 12:26:01 2012
New Revision: 1420639

URL: http://svn.apache.org/viewvc?rev=1420639&view=rev
Log:
fixed PredictionStrategy#debugOutput and first phase of backprop

Modified:
    labs/yay/trunk/core/src/main/java/org/apache/yay/BackPropagationLearningStrategy.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/FeedForwardStrategy.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/PredictionStrategy.java

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/BackPropagationLearningStrategy.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/BackPropagationLearningStrategy.java?rev=1420639&r1=1420638&r2=1420639&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/BackPropagationLearningStrategy.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/BackPropagationLearningStrategy.java Wed Dec 12 12:26:01 2012
@@ -43,14 +43,16 @@ public class BackPropagationLearningStra
   @Override
   public WeightsMatrix[] learnWeights(WeightsMatrix[] weightsMatrixSet, Collection<TrainingExample<Double, Double[]>> trainingExamples) throws WeightLearningException {
     // set up the accumulator matrix(es)
-    RealMatrix[] deltas = new RealMatrix[weightsMatrixSet.length];
+    RealMatrix[] triangle = new RealMatrix[weightsMatrixSet.length];
     for (TrainingExample<Double, Double[]> trainingExample : trainingExamples) {
       try {
         // contains activation errors for the current training example
+        // TODO : check if this should be RealVector[] < probably yes
         RealMatrix[] activationErrors = new RealMatrix[weightsMatrixSet.length - 1];
 
         // feed forward propagation
-        RealMatrix output = predictionStrategy.debugOutput(ConversionUtils.toVector(trainingExample.getFeatureVector()), weightsMatrixSet);
+        RealMatrix[] activations = predictionStrategy.debugOutput(ConversionUtils.toVector(trainingExample.getFeatureVector()), weightsMatrixSet);
+        RealMatrix output = activations[activations.length - 1];
         Double[] learnedOutput = trainingExample.getOutput(); // training example output
         RealVector predictedOutputVector = new ArrayRealVector(output.getColumn(output.getColumnDimension() - 1)); // turn output to vector
         RealVector learnedOutputRealVector = new ArrayRealVector(learnedOutput); // turn example output to a vector
@@ -62,32 +64,34 @@ public class BackPropagationLearningStra
 
         // back prop the error and update the activationErrors accordingly
         // TODO : remove the bias term from the error calculations
-        for (int l = weightsMatrixSet.length - 2; l > 0; l--) {
-          WeightsMatrix currentMatrix = weightsMatrixSet[l];
-          ArrayRealVector realVector = new ArrayRealVector(output.getColumn(l)); // get l-th nn layer activations
-          ArrayRealVector identity = new ArrayRealVector(realVector.getDimension(), 1d);
-          RealVector gz = realVector.ebeMultiply(identity.subtract(realVector)); // = a^l .* (1-a^l)
-          RealVector resultingDeltaVector = currentMatrix.transpose().preMultiply(nextLayerDelta).ebeMultiply(gz);
+        for (int l = weightsMatrixSet.length - 2; l >= 0; l--) {
+          WeightsMatrix thetaL = weightsMatrixSet[l];
+          ArrayRealVector activationsVector = new ArrayRealVector(activations[l].getRowVector(0)); // get l-th nn layer activations
+          ArrayRealVector identity = new ArrayRealVector(activationsVector.getDimension(), 1d);
+          RealVector gz = activationsVector.ebeMultiply(identity.subtract(activationsVector)); // = a^l .* (1-a^l)
+          RealVector resultingDeltaVector = thetaL.transpose().preMultiply(nextLayerDelta).ebeMultiply(gz);
           if (activationErrors[l] == null) {
             activationErrors[l] = new Array2DRowRealMatrix(new ArrayRealVector(resultingDeltaVector.getDimension(), 1d).toArray());
           }
           activationErrors[l] = new Array2DRowRealMatrix(resultingDeltaVector.toArray());
+          nextLayerDelta = resultingDeltaVector;
         }
 
         // update the accumulator matrix
-        for (int l = 0; l < deltas.length - 1; l++) {
-          if (deltas[l] == null) {
-            deltas[l] = new Array2DRowRealMatrix(weightsMatrixSet[l].getColumnDimension(), weightsMatrixSet[l].getRowDimension());
+        for (int l = 0; l < triangle.length - 1; l++) {
+          if (triangle[l] == null) {
+            triangle[l] = new Array2DRowRealMatrix(weightsMatrixSet[l].getRowDimension(), weightsMatrixSet[l].getColumnDimension());
           }
-          deltas[l] = deltas[l].add(deltas[l + 1]).multiply(weightsMatrixSet[l].transpose());
+          triangle[l] = triangle[l].add(activationErrors[l+1].getRowVector(0).outerProduct(activations[l].getRowVector(0)));
         }
 
       } catch (Exception e) {
         throw new WeightLearningException("error during phase 1 of back-propagation algorithm", e);
       }
     }
-    for (int i = 0; i < deltas.length; i++) {
-      deltas[i] = deltas[i].scalarMultiply(1 / trainingExamples.size());
+    for (int i = 0; i < triangle.length; i++) {
+      // TODO : introduce regularization diversification on bias term (currently not regularized)
+      triangle[i] = triangle[i].scalarMultiply(1 / trainingExamples.size());
     }
 
     // TODO : now apply gradient descent (or other optimization/minimization algorithms) with these derivative terms and the cost function

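For reference, the accumulation loop above is intended to implement the standard backpropagation gradient estimate. Stated in the same notation as the code comments (Theta^l are the weights feeding layer l+1, a^l the layer-l activations, delta^l the layer-l errors, m the number of training examples); this is a summary of the intent, not part of the committed code:

    delta^l  = (Theta^l)^T * delta^(l+1) .* a^l .* (1 - a^l)   // error propagated back to layer l
    Delta^l  = Delta^l + delta^(l+1) * (a^l)^T                 // per-example accumulation (the "triangle" matrices)
    D^l      = (1/m) * Delta^l                                 // averaged partial derivatives w.r.t. Theta^l

Note that 1 / trainingExamples.size() in the averaging step is evaluated with integer arithmetic, so a floating-point form such as 1d / trainingExamples.size() is presumably what the 1/m scaling intends.
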
Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/FeedForwardStrategy.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/FeedForwardStrategy.java?rev=1420639&r1=1420638&r2=1420639&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/FeedForwardStrategy.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/FeedForwardStrategy.java Wed Dec 12 12:26:01 2012
@@ -49,23 +49,26 @@ public class FeedForwardStrategy impleme
     this.hypothesis = hypothesis;
   }
 
-
   @Override
   public Double predictOutput(Vector<Double> input, WeightsMatrix[] weightsMatrixSet) {
-    RealMatrix x = applyFF(input, weightsMatrixSet);
+    RealMatrix[] realMatrixes = applyFF(input, weightsMatrixSet);
+    RealMatrix x = realMatrixes[realMatrixes.length - 1];
     double[] lastColumn = x.getColumn(x.getColumnDimension() - 1);
     return Collections.max(Arrays.asList(ArrayUtils.toObject(lastColumn)));
   }
 
-  public RealMatrix debugOutput(Vector<Double> input, WeightsMatrix[] weightsMatrixSet) {
+  public RealMatrix[] debugOutput(Vector<Double> input, WeightsMatrix[] weightsMatrixSet) {
     return applyFF(input, weightsMatrixSet);
   }
 
-  private RealMatrix applyFF(Vector<Double> input, WeightsMatrix[] weightsMatrixSet) {
+  private RealMatrix[] applyFF(Vector<Double> input, WeightsMatrix[] weightsMatrixSet) {
+    RealMatrix[] debugOutput = new RealMatrix[weightsMatrixSet.length];
+
     // TODO : fix this impl as it's very slow
     RealVector v = ConversionUtils.toRealVector(input);
     RealMatrix x = v.outerProduct(new ArrayRealVector(new Double[]{1d})).transpose(); // a 1xN matrix
-    for (WeightsMatrix weightsMatrix : weightsMatrixSet) {
+    for (int w = 0; w < weightsMatrixSet.length; w++) {
+      WeightsMatrix weightsMatrix = weightsMatrixSet[w];
       // compute matrix multiplication
       x = x.multiply(weightsMatrix.transpose());
 
@@ -83,8 +86,9 @@ public class FeedForwardStrategy impleme
         }
         x.setRow(i, finRow);
       }
+      debugOutput[w] = x;
     }
-    return x;
+    return debugOutput;
   }
 
   private class HypothesisRowTransformer implements Transformer {

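Since applyFF now records the 1xN activation matrix produced after every layer, debugOutput exposes the whole forward pass (a^(l+1) = g(Theta^l * a^l) at each step) instead of only the final output. A minimal consumer sketch, assuming the Commons Math types already used in this diff; dumpActivations, its parameters and the generic arguments are illustrative names, not part of this commit:

  // illustrative only: walks the per-layer activations returned by the new debugOutput
  private void dumpActivations(PredictionStrategy<Double, Double> strategy, Vector<Double> input, WeightsMatrix[] weights) {
    RealMatrix[] activations = strategy.debugOutput(input, weights);
    for (int l = 0; l < activations.length; l++) {
      // each entry is the 1xN activation row computed after applying the l-th weights matrix
      RealVector layerActivations = activations[l].getRowVector(0);
      System.out.println("activations after layer " + (l + 1) + ": " + layerActivations);
    }
    // the last entry is the network output that predictOutput() reduces to a single Double
    RealMatrix output = activations[activations.length - 1];
    System.out.println("output: " + output.getRowVector(0));
  }
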
Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/PredictionStrategy.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/PredictionStrategy.java?rev=1420639&r1=1420638&r2=1420639&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/PredictionStrategy.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/PredictionStrategy.java Wed Dec 12 12:26:01 2012
@@ -29,6 +29,6 @@ public interface PredictionStrategy<I, O
 
   public O predictOutput(Vector<I> input, WeightsMatrix[] weightsMatrixSet);
 
-  public RealMatrix debugOutput(Vector<Double> input, WeightsMatrix[] weightsMatrixSet);
+  public RealMatrix[] debugOutput(Vector<Double> input, WeightsMatrix[] weightsMatrixSet);
 
 }


