Author: tommaso
Date: Wed Dec 12 12:26:01 2012
New Revision: 1420639
URL: http://svn.apache.org/viewvc?rev=1420639&view=rev
Log:
Fixed PredictionStrategy#debugOutput and the first phase of back-propagation
Modified:
labs/yay/trunk/core/src/main/java/org/apache/yay/BackPropagationLearningStrategy.java
labs/yay/trunk/core/src/main/java/org/apache/yay/FeedForwardStrategy.java
labs/yay/trunk/core/src/main/java/org/apache/yay/PredictionStrategy.java
Modified:
labs/yay/trunk/core/src/main/java/org/apache/yay/BackPropagationLearningStrategy.java
URL:
http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/BackPropagationLearningStrategy.java?rev=1420639&r1=1420638&r2=1420639&view=diff
==============================================================================
---
labs/yay/trunk/core/src/main/java/org/apache/yay/BackPropagationLearningStrategy.java
(original)
+++
labs/yay/trunk/core/src/main/java/org/apache/yay/BackPropagationLearningStrategy.java
Wed Dec 12 12:26:01 2012
@@ -43,14 +43,16 @@ public class BackPropagationLearningStra
@Override
public WeightsMatrix[] learnWeights(WeightsMatrix[] weightsMatrixSet,
Collection<TrainingExample<Double, Double[]>> trainingExamples) throws
WeightLearningException {
// set up the accumulator matrix(es)
- RealMatrix[] deltas = new RealMatrix[weightsMatrixSet.length];
+ RealMatrix[] triangle = new RealMatrix[weightsMatrixSet.length];
for (TrainingExample<Double, Double[]> trainingExample : trainingExamples)
{
try {
// contains activation errors for the current training example
+ // TODO : check if this should be RealVector[] < probably yes
RealMatrix[] activationErrors = new RealMatrix[weightsMatrixSet.length
- 1];
// feed forward propagation
- RealMatrix output =
predictionStrategy.debugOutput(ConversionUtils.toVector(trainingExample.getFeatureVector()),
weightsMatrixSet);
+ RealMatrix[] activations =
predictionStrategy.debugOutput(ConversionUtils.toVector(trainingExample.getFeatureVector()),
weightsMatrixSet);
+ RealMatrix output = activations[activations.length - 1];
Double[] learnedOutput = trainingExample.getOutput(); // training
example output
RealVector predictedOutputVector = new
ArrayRealVector(output.getColumn(output.getColumnDimension() - 1)); // turn
output to vector
RealVector learnedOutputRealVector = new
ArrayRealVector(learnedOutput); // turn example output to a vector
@@ -62,32 +64,34 @@ public class BackPropagationLearningStra
// back prop the error and update the activationErrors accordingly
// TODO : remove the bias term from the error calculations
- for (int l = weightsMatrixSet.length - 2; l > 0; l--) {
- WeightsMatrix currentMatrix = weightsMatrixSet[l];
- ArrayRealVector realVector = new
ArrayRealVector(output.getColumn(l)); // get l-th nn layer activations
- ArrayRealVector identity = new
ArrayRealVector(realVector.getDimension(), 1d);
- RealVector gz =
realVector.ebeMultiply(identity.subtract(realVector)); // = a^l .* (1-a^l)
- RealVector resultingDeltaVector =
currentMatrix.transpose().preMultiply(nextLayerDelta).ebeMultiply(gz);
+ for (int l = weightsMatrixSet.length - 2; l >= 0; l--) {
+ WeightsMatrix thetaL = weightsMatrixSet[l];
+ ArrayRealVector activationsVector = new
ArrayRealVector(activations[l].getRowVector(0)); // get l-th nn layer
activations
+ ArrayRealVector identity = new
ArrayRealVector(activationsVector.getDimension(), 1d);
+ RealVector gz =
activationsVector.ebeMultiply(identity.subtract(activationsVector)); // = a^l
.* (1-a^l)
+ RealVector resultingDeltaVector =
thetaL.transpose().preMultiply(nextLayerDelta).ebeMultiply(gz);
if (activationErrors[l] == null) {
activationErrors[l] = new Array2DRowRealMatrix(new
ArrayRealVector(resultingDeltaVector.getDimension(), 1d).toArray());
}
activationErrors[l] = new
Array2DRowRealMatrix(resultingDeltaVector.toArray());
+ nextLayerDelta = resultingDeltaVector;
}
// update the accumulator matrix
- for (int l = 0; l < deltas.length - 1; l++) {
- if (deltas[l] == null) {
- deltas[l] = new
Array2DRowRealMatrix(weightsMatrixSet[l].getColumnDimension(),
weightsMatrixSet[l].getRowDimension());
+ for (int l = 0; l < triangle.length - 1; l++) {
+ if (triangle[l] == null) {
+ triangle[l] = new
Array2DRowRealMatrix(weightsMatrixSet[l].getRowDimension(),
weightsMatrixSet[l].getColumnDimension());
}
- deltas[l] = deltas[l].add(deltas[l +
1]).multiply(weightsMatrixSet[l].transpose());
+ triangle[l] =
triangle[l].add(activationErrors[l+1].getRowVector(0).outerProduct(activations[l].getRowVector(0)));
}
} catch (Exception e) {
throw new WeightLearningException("error during phase 1 of
back-propagation algorithm", e);
}
}
- for (int i = 0; i < deltas.length; i++) {
- deltas[i] = deltas[i].scalarMultiply(1 / trainingExamples.size());
+ for (int i = 0; i < triangle.length; i++) {
+ // TODO : introduce regularization diversification on bias term
(currently not regularized)
+ triangle[i] = triangle[i].scalarMultiply(1 / trainingExamples.size());
}
// TODO : now apply gradient descent (or other optimization/minimization
algorithms) with this derivative terms and the cost function
Modified:
labs/yay/trunk/core/src/main/java/org/apache/yay/FeedForwardStrategy.java
URL:
http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/FeedForwardStrategy.java?rev=1420639&r1=1420638&r2=1420639&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/FeedForwardStrategy.java
(original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/FeedForwardStrategy.java
Wed Dec 12 12:26:01 2012
@@ -49,23 +49,26 @@ public class FeedForwardStrategy impleme
this.hypothesis = hypothesis;
}
-
@Override
public Double predictOutput(Vector<Double> input, WeightsMatrix[]
weightsMatrixSet) {
- RealMatrix x = applyFF(input, weightsMatrixSet);
+ RealMatrix[] realMatrixes = applyFF(input, weightsMatrixSet);
+ RealMatrix x = realMatrixes[realMatrixes.length - 1];
double[] lastColumn = x.getColumn(x.getColumnDimension() - 1);
return Collections.max(Arrays.asList(ArrayUtils.toObject(lastColumn)));
}
- public RealMatrix debugOutput(Vector<Double> input, WeightsMatrix[]
weightsMatrixSet) {
+ public RealMatrix[] debugOutput(Vector<Double> input, WeightsMatrix[]
weightsMatrixSet) {
return applyFF(input, weightsMatrixSet);
}
- private RealMatrix applyFF(Vector<Double> input, WeightsMatrix[]
weightsMatrixSet) {
+ private RealMatrix[] applyFF(Vector<Double> input, WeightsMatrix[]
weightsMatrixSet) {
+ RealMatrix[] debugOutput = new RealMatrix[weightsMatrixSet.length];
+
// TODO : fix this impl as it's very slow
RealVector v = ConversionUtils.toRealVector(input);
RealMatrix x = v.outerProduct(new ArrayRealVector(new
Double[]{1d})).transpose(); // a 1xN matrix
- for (WeightsMatrix weightsMatrix : weightsMatrixSet) {
+ for (int w = 0; w < weightsMatrixSet.length; w++) {
+ WeightsMatrix weightsMatrix = weightsMatrixSet[w];
// compute matrix multiplication
x = x.multiply(weightsMatrix.transpose());
@@ -83,8 +86,9 @@ public class FeedForwardStrategy impleme
}
x.setRow(i, finRow);
}
+ debugOutput[w] = x;
}
- return x;
+ return debugOutput;
}
private class HypothesisRowTransformer implements Transformer {
Modified:
labs/yay/trunk/core/src/main/java/org/apache/yay/PredictionStrategy.java
URL:
http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/PredictionStrategy.java?rev=1420639&r1=1420638&r2=1420639&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/PredictionStrategy.java
(original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/PredictionStrategy.java
Wed Dec 12 12:26:01 2012
@@ -29,6 +29,6 @@ public interface PredictionStrategy<I, O
public O predictOutput(Vector<I> input, WeightsMatrix[] weightsMatrixSet);
- public RealMatrix debugOutput(Vector<Double> input, WeightsMatrix[]
weightsMatrixSet);
+ public RealMatrix[] debugOutput(Vector<Double> input, WeightsMatrix[]
weightsMatrixSet);
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]