Author: tommaso
Date: Wed Aug 1 21:43:54 2012
New Revision: 1368275
URL: http://svn.apache.org/viewvc?rev=1368275&view=rev
Log:
migrating to commons-math 3.0, continuing work on back prop, plus minor fixes
Modified:
labs/yay/trunk/core/pom.xml
labs/yay/trunk/core/src/main/java/org/apache/yay/BackPropagationLearningStrategy.java
labs/yay/trunk/core/src/main/java/org/apache/yay/FeedForwardStrategy.java
labs/yay/trunk/core/src/main/java/org/apache/yay/LogisticRegressionCostFunction.java
labs/yay/trunk/core/src/main/java/org/apache/yay/PredictionStrategy.java
labs/yay/trunk/core/src/main/java/org/apache/yay/WeightsMatrix.java
labs/yay/trunk/core/src/main/java/org/apache/yay/utils/ConversionUtils.java
Modified: labs/yay/trunk/core/pom.xml
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/pom.xml?rev=1368275&r1=1368274&r2=1368275&view=diff
==============================================================================
--- labs/yay/trunk/core/pom.xml (original)
+++ labs/yay/trunk/core/pom.xml Wed Aug 1 21:43:54 2012
@@ -19,8 +19,8 @@
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
- <artifactId>commons-math</artifactId>
- <version>2.2</version>
+ <artifactId>commons-math3</artifactId>
+ <version>3.0</version>
</dependency>
<dependency>
<groupId>org.apache.commons</groupId>
Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/BackPropagationLearningStrategy.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/BackPropagationLearningStrategy.java?rev=1368275&r1=1368274&r2=1368275&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/BackPropagationLearningStrategy.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/BackPropagationLearningStrategy.java Wed Aug 1 21:43:54 2012
@@ -20,9 +20,9 @@ package org.apache.yay;
import java.util.Collection;
-import org.apache.commons.math.linear.ArrayRealVector;
-import org.apache.commons.math.linear.RealMatrix;
-import org.apache.commons.math.linear.RealVector;
+import org.apache.commons.math3.linear.ArrayRealVector;
+import org.apache.commons.math3.linear.RealMatrix;
+import org.apache.commons.math3.linear.RealVector;
import org.apache.yay.utils.ConversionUtils;
/**
@@ -41,6 +41,9 @@ public class BackPropagationLearningStra
public WeightsMatrix[] learnWeights(WeightsMatrix[] weightsMatrixSet,
Collection<TrainingExample<Double, double[]>> trainingExamples) throws
WeightLearningException {
for (TrainingExample<Double, double[]> trainingExample : trainingExamples)
{
try {
+
+ RealMatrix[] deltas = new RealMatrix[weightsMatrixSet.length - 1];
+
RealMatrix output =
predictionStrategy.debugOutput(ConversionUtils.toVector(trainingExample.getFeatureVector()),
weightsMatrixSet);
double[] learnedOutput = trainingExample.getOutput();
RealVector predictedOutputVector = new
ArrayRealVector(output.getColumn(output.getColumnDimension() - 1));
@@ -54,7 +57,8 @@ public class BackPropagationLearningStra
ArrayRealVector realVector = new
ArrayRealVector(output.getColumn(i));
ArrayRealVector identity = new
ArrayRealVector(realVector.getDimension(), 1d);
RealVector gz =
realVector.ebeMultiply(identity.subtract(realVector)); // = a^i .* (1-a^i)
- RealVector resultingLambdaVector =
currentMatrix.transpose().preMultiply(error).ebeMultiply(gz);
+ RealVector resultingDeltaVector =
currentMatrix.transpose().preMultiply(error).ebeMultiply(gz);
+// deltas[i] = deltas[i].add()
}
} catch (Exception e) {
Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/FeedForwardStrategy.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/FeedForwardStrategy.java?rev=1368275&r1=1368274&r2=1368275&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/FeedForwardStrategy.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/FeedForwardStrategy.java Wed Aug 1 21:43:54 2012
@@ -19,19 +19,19 @@
package org.apache.yay;
-import org.apache.commons.collections.CollectionUtils;
-import org.apache.commons.collections.Transformer;
-import org.apache.commons.lang3.ArrayUtils;
-import org.apache.commons.math.linear.ArrayRealVector;
-import org.apache.commons.math.linear.RealMatrix;
-import org.apache.commons.math.linear.RealVector;
-import org.apache.yay.utils.ConversionUtils;
-
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Vector;
+import org.apache.commons.collections.CollectionUtils;
+import org.apache.commons.collections.Transformer;
+import org.apache.commons.lang3.ArrayUtils;
+import org.apache.commons.math3.linear.ArrayRealVector;
+import org.apache.commons.math3.linear.RealMatrix;
+import org.apache.commons.math3.linear.RealVector;
+import org.apache.yay.utils.ConversionUtils;
+
/**
* Octave code for FF to be converted :
* m = size(X, 1);
Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/LogisticRegressionCostFunction.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/LogisticRegressionCostFunction.java?rev=1368275&r1=1368274&r2=1368275&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/LogisticRegressionCostFunction.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/LogisticRegressionCostFunction.java Wed Aug 1 21:43:54 2012
@@ -18,7 +18,9 @@
*/
package org.apache.yay;
+import java.util.Arrays;
import java.util.Collection;
+import java.util.Collections;
/**
* This calculates the logistic regression cost function for neural networks
@@ -60,9 +62,10 @@ public class LogisticRegressionCostFunct
Collection<TrainingExample<Double,
Double>> trainingExamples) throws PredictionException, CreationException {
Double res = 0d;
NeuralNetwork<Double, Double> neuralNetwork =
NeuralNetworkFactory.create(
- (Collection<TrainingExample<Double, Double>>) trainingExamples,
- parameters, new VoidLearningStrategy<Double, Double>(), new
FeedForwardStrategy(
- (ActivationFunction<Double>) hypothesis));
+ trainingExamples,parameters, new VoidLearningStrategy<Double,
Double>(),
+ new FeedForwardStrategy(hypothesis));
+
+
for (TrainingExample<Double, Double> input : trainingExamples) {
// TODO : handle this for multiple outputs (multi class
classification)
Double predictedOutput = neuralNetwork.predict(input);
Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/PredictionStrategy.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/PredictionStrategy.java?rev=1368275&r1=1368274&r2=1368275&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/PredictionStrategy.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/PredictionStrategy.java Wed Aug 1 21:43:54 2012
@@ -20,7 +20,7 @@ package org.apache.yay;
import java.util.Vector;
-import org.apache.commons.math.linear.RealMatrix;
+import org.apache.commons.math3.linear.RealMatrix;
/**
* A {@link PredictionStrategy} defines an algorithm for the prediction of an
output of type O given an input of type I
Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/WeightsMatrix.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/WeightsMatrix.java?rev=1368275&r1=1368274&r2=1368275&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/WeightsMatrix.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/WeightsMatrix.java Wed Aug 1 21:43:54 2012
@@ -18,7 +18,7 @@
*/
package org.apache.yay;
-import org.apache.commons.math.linear.Array2DRowRealMatrix;
+import org.apache.commons.math3.linear.Array2DRowRealMatrix;
/**
* A matrix representing the weights applied to links between elaboration
units of different adjacent {@link Layer}s in
Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/utils/ConversionUtils.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/utils/ConversionUtils.java?rev=1368275&r1=1368274&r2=1368275&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/utils/ConversionUtils.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/utils/ConversionUtils.java Wed Aug 1 21:43:54 2012
@@ -18,21 +18,26 @@
*/
package org.apache.yay.utils;
-import org.apache.commons.math.linear.Array2DRowRealMatrix;
-import org.apache.commons.math.linear.OpenMapRealVector;
-import org.apache.commons.math.linear.RealMatrix;
-import org.apache.commons.math.linear.RealVector;
-import org.apache.yay.Example;
-import org.apache.yay.Feature;
-
import java.util.Collection;
import java.util.Vector;
+import org.apache.commons.math3.linear.Array2DRowRealMatrix;
+import org.apache.commons.math3.linear.OpenMapRealVector;
+import org.apache.commons.math3.linear.RealMatrix;
+import org.apache.commons.math3.linear.RealVector;
+import org.apache.yay.Example;
+import org.apache.yay.Feature;
+
/**
* Temporary class for conversion between model objects and commons-math
matrices/vectors
*/
public class ConversionUtils {
+ /**
+ * Converts a set of examples to a matrix of inputs with features
+ * @param trainingSet samples with features of type Double
+ * @return a real matrix
+ */
public static RealMatrix toMatrix(Collection<Example<Double>> trainingSet) {
double[][] matrixData = new double[trainingSet.size()][];
@@ -46,6 +51,11 @@ public class ConversionUtils {
return new Array2DRowRealMatrix(matrixData);
}
+ /**
+ * converts an example with Double features to a double array
+ * @param sample the sample to convert
+ * @return a double array
+ */
private static double[] toDoubleArray(Example<Double> sample) {
double[] ar = new double[sample.getFeatureVector().size()];
int i = 0;
@@ -56,10 +66,20 @@ public class ConversionUtils {
return ar;
}
+ /**
+ * Converts a {@link Vector} of {@link Double}s into a commons-math {@link RealVector}.
+ * The elements are first copied into a fresh {@code Double[]} (via {@code Vector.toArray}),
+ * so later changes to {@code input} do not affect the returned vector. The concrete
+ * result type is {@link OpenMapRealVector}, i.e. a sparse vector implementation.
+ * NOTE(review): null elements in {@code input} presumably cause a NullPointerException
+ * when the library unboxes them — confirm against the commons-math3 docs.
+ * @param input a vector of Double objects; must not be null
+ * @return a real vector with the same dimension and values as {@code input}
+ */
public static RealVector toRealVector(Vector<Double> input) {
return new OpenMapRealVector(input.toArray(new Double[input.size()]));
}
+ /**
+ * turns a vector of features of type Double into a vector of Doubles
+ * @param featureVector the vector of features
+ * @return a vector of Doubles
+ */
public static Vector<Double> toVector(Vector<Feature<Double>> featureVector)
{
// TODO : remove this and change APIs in a way that doesn't force to go
through this ugly loop
Vector<Double> resultVector = new Vector<Double>(featureVector.size());
@@ -69,4 +89,16 @@ public class ConversionUtils {
return resultVector;
}
+ /**
+ * Boxes a primitive {@code double[]} into a {@code Double[]} of the same length;
+ * element {@code i} of the result is {@code Double.valueOf(ar[i])}.
+ * Presumably needed because commons-math returns primitive arrays (e.g. matrix
+ * columns) while other APIs here work with boxed Doubles — confirm with callers.
+ * Not to be confused with the private {@code toDoubleArray(Example<Double>)}
+ * overload, which goes the other way and produces a primitive array.
+ * @param ar a double array; must not be null ({@code ar.length} is read directly)
+ * @return a new Double array holding the boxed values of {@code ar}, in order
+ */
+ public static Double[] toDoubleArray(double[] ar) {
+ Double[] doubles = new Double[ar.length];
+ for (int i = 0; i < ar.length; i++) {
+ doubles[i] = Double.valueOf(ar[i]);
+ }
+ return doubles;
+ }
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]