Author: tommaso
Date: Wed Jun 26 20:09:51 2013
New Revision: 1497067
URL: http://svn.apache.org/r1497067
Log:
added backpropagation test with random network
Modified:
labs/yay/trunk/core/src/test/java/org/apache/yay/core/BackPropagationLearningStrategyTest.java
Modified:
labs/yay/trunk/core/src/test/java/org/apache/yay/core/BackPropagationLearningStrategyTest.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/core/BackPropagationLearningStrategyTest.java?rev=1497067&r1=1497066&r2=1497067&view=diff
==============================================================================
--- labs/yay/trunk/core/src/test/java/org/apache/yay/core/BackPropagationLearningStrategyTest.java (original)
+++ labs/yay/trunk/core/src/test/java/org/apache/yay/core/BackPropagationLearningStrategyTest.java Wed Jun 26 20:09:51 2013
@@ -18,6 +18,10 @@
*/
package org.apache.yay.core;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Random;
+
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.yay.PredictionStrategy;
@@ -26,9 +30,6 @@ import org.apache.yay.TrainingSet;
import org.apache.yay.core.utils.ExamplesFactory;
import org.junit.Test;
-import java.util.ArrayList;
-import java.util.Collection;
-
import static junit.framework.Assert.assertFalse;
import static junit.framework.Assert.assertNotNull;
@@ -37,6 +38,70 @@ import static junit.framework.Assert.ass
*/
public class BackPropagationLearningStrategyTest {
+  /**
+   * Trains a network whose depth and layer sizes are chosen at random on 1000 random samples,
+   * and checks that back-propagation changes the weights of every layer.
+   */
+  @Test
+  public void testLearningWitgRandomNetwork() throws Exception {
+    BackPropagationLearningStrategy backPropagationLearningStrategy =
+        new BackPropagationLearningStrategy();
+
+    RealMatrix[] initialWeights = createRandomWeights();
+
+    // Snapshot the starting weights before training: if learnWeights updates the matrices it is
+    // given in place, comparing against initialWeights afterwards would compare a matrix with
+    // itself and the "weights changed" check below would be meaningless.
+    RealMatrix[] originalWeights = new RealMatrix[initialWeights.length];
+    for (int i = 0; i < initialWeights.length; i++) {
+      originalWeights[i] = initialWeights[i].copy();
+    }
+
+    // One feature fewer than the first layer's column count: column 0 of every generated layer
+    // is fixed to 1 (presumably the bias weight), so it does not correspond to a sample feature.
+    Collection<TrainingExample<Double, Double>> samples =
+        createSamples(1000, initialWeights[0].getColumnDimension() - 1);
+    TrainingSet<Double, Double> trainingSet = new TrainingSet<Double, Double>(samples);
+    RealMatrix[] learntWeights =
+        backPropagationLearningStrategy.learnWeights(initialWeights, trainingSet);
+    assertNotNull(learntWeights);
+    // Fail with a readable message rather than an ArrayIndexOutOfBoundsException in the loop.
+    assertFalse("expected " + originalWeights.length + " weight matrices but got "
+        + learntWeights.length, learntWeights.length != originalWeights.length);
+
+    for (int i = 0; i < learntWeights.length; i++) {
+      // (i + 1) must be parenthesised: "level " + i + 1 would concatenate to e.g. "level 01".
+      assertFalse("weights have not been changed at level " + (i + 1),
+          learntWeights[i].equals(originalWeights[i]));
+    }
+  }
+
+  /**
+   * Creates initial weights for a randomly shaped layered network, using an unseeded
+   * {@link Random}.
+   *
+   * @return the weight matrices, one per layer; see {@link #createRandomWeights(Random)}
+   */
+  private RealMatrix[] createRandomWeights() {
+    return createRandomWeights(new Random());
+  }
+
+  /**
+   * Creates initial weights for a randomly shaped layered network.
+   *
+   * <p>The network has 2 to 6 weight matrices. The first matrix has 2 to 5 columns. Each later
+   * matrix has as many columns as the previous one has rows, so the matrices chain. Every matrix
+   * except the last has 2 to 5 rows, and the last has exactly 1 row.
+   *
+   * <p>In every matrix except the last, row 0 is all zeros. Every other row has 1 in column 0 and
+   * values in [0, 1) elsewhere. This mirrors the layout of the hand-written matrices in
+   * testLearningWithDefaultSettingsAndRandomSamples.
+   *
+   * @param random source of randomness; pass a seeded instance to reproduce a failing network
+   * @return the weight matrices, one per layer, in input-to-output order
+   */
+  private RealMatrix[] createRandomWeights(Random random) {
+    // nextInt(bound) is always in [0, bound). The previous Math.abs(nextInt()) % bound is
+    // negative when nextInt() returns Integer.MIN_VALUE, which produced negative array sizes.
+    int layerCount = 2 + random.nextInt(5);
+    RealMatrix[] weights = new RealMatrix[layerCount];
+    for (int i = 0; i < layerCount; i++) {
+      boolean outputLayer = i == layerCount - 1;
+      int cols = i == 0 ? 2 + random.nextInt(4) : weights[i - 1].getRowDimension();
+      int rows = outputLayer ? 1 : 2 + random.nextInt(4);
+      weights[i] = new Array2DRowRealMatrix(createRandomLayerWeights(rows, cols, outputLayer, random));
+    }
+    return weights;
+  }
+
+  /**
+   * Fills a single rows x cols weight matrix: column 0 is 1 and the remaining columns are random
+   * values in [0, 1). For a non-output layer, row 0 is left all zeros.
+   */
+  private static double[][] createRandomLayerWeights(int rows, int cols, boolean outputLayer,
+      Random random) {
+    double[][] d = new double[rows][cols];
+    // Non-output layers skip row 0, which keeps Java's default 0.0 in every column.
+    int firstFilledRow = outputLayer ? 0 : 1;
+    for (int k = firstFilledRow; k < rows; k++) {
+      d[k][0] = 1d;
+      for (int j = 1; j < cols; j++) {
+        d[k][j] = random.nextDouble();
+      }
+    }
+    return d;
+  }
+
@Test
public void testLearningWithDefaultSettingsAndRandomSamples() throws
Exception {
@@ -44,9 +109,9 @@ public class BackPropagationLearningStra
// 3 input units, 3 hidden units, 4 hidden units, 1 output unit
RealMatrix[] initialWeights = new RealMatrix[3];
- initialWeights[0] = new Array2DRowRealMatrix(new double[][]{{0d, 0d, 0d},
{1d, 0.6d, 3d}, {1d, 2d, 2d}, {1d, 0.6d, 3d}});
- initialWeights[1] = new Array2DRowRealMatrix(new double[][]{{0d, 0d, 0d,
0d}, {1d, 0.5d, 1d, 0.5d}, {1d, 0.1d, 8d, 0.1d}, {1d, 0.1d, 8d, 0.2d}});
- initialWeights[2] = new Array2DRowRealMatrix(new double[][]{{1d, 2d, 0.3d,
0.5d}});
+ initialWeights[0] = new Array2DRowRealMatrix(new double[][]{{0d, 0d, 0d},
{1d, 0.6d, 3d}, {1d, 2d, 2d}, {1d, 0.6d, 3d}}); // 4 x 3
+ initialWeights[1] = new Array2DRowRealMatrix(new double[][]{{0d, 0d, 0d,
0d}, {1d, 0.5d, 1d, 0.5d}, {1d, 0.1d, 8d, 0.1d}, {1d, 0.1d, 8d, 0.2d}}); // 4 x
4
+ initialWeights[2] = new Array2DRowRealMatrix(new double[][]{{1d, 2d, 0.3d,
0.5d}}); // 1 x 4
Collection<TrainingExample<Double, Double>> samples = createSamples(1000,
2);
TrainingSet<Double, Double> trainingSet = new TrainingSet<Double,
Double>(samples);
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]