Author: tommaso
Date: Wed Jun 26 20:59:08 2013
New Revision: 1497094

URL: http://svn.apache.org/r1497094
Log:
backprop testing improved: added a test with a random network and random settings, increased training sample sizes
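
The new randomized test below calls a createRandomWeights() helper that is not part of this diff. A minimal sketch of what such a helper could look like, assuming the layered RealMatrix[] layout used by the fixed-weight tests (a 3 x 3 first weight matrix to match the 4 x 4 and 1 x 4 matrices shown in the context lines) and Commons Math's Array2DRowRealMatrix; the real helper presumably also randomizes the layer sizes, which this sketch fixes for brevity:

  // Hypothetical sketch, not the project's actual helper: builds random weights for a
  // 2-feature network with layer shapes 3 x 3, 4 x 4 and 1 x 4 (columns = units in the
  // previous layer plus one bias entry, rows = units in the current layer).
  private RealMatrix[] createRandomWeights() {
    int[] rows = {3, 4, 1};
    int[] cols = {3, 4, 4};
    RealMatrix[] weights = new RealMatrix[rows.length];
    for (int l = 0; l < rows.length; l++) {
      double[][] values = new double[rows[l]][cols[l]];
      for (int i = 0; i < rows[l]; i++) {
        for (int j = 0; j < cols[l]; j++) {
          values[i][j] = Math.random();
        }
      }
      weights[l] = new Array2DRowRealMatrix(values);
    }
    return weights;
  }

With weights shaped this way, createSamples(1000, initialWeights[0].getColumnDimension() - 1) would generate examples with two features, matching the bias-plus-two-inputs first weight matrix.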

Modified:
    labs/yay/trunk/core/src/test/java/org/apache/yay/core/BackPropagationLearningStrategyTest.java
    labs/yay/trunk/core/src/test/java/org/apache/yay/core/BasicPerceptronTest.java

Modified: labs/yay/trunk/core/src/test/java/org/apache/yay/core/BackPropagationLearningStrategyTest.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/core/BackPropagationLearningStrategyTest.java?rev=1497094&r1=1497093&r2=1497094&view=diff
==============================================================================
--- labs/yay/trunk/core/src/test/java/org/apache/yay/core/BackPropagationLearningStrategyTest.java (original)
+++ labs/yay/trunk/core/src/test/java/org/apache/yay/core/BackPropagationLearningStrategyTest.java Wed Jun 26 20:59:08 2013
@@ -50,7 +50,25 @@ public class BackPropagationLearningStra
     assertNotNull(learntWeights);
 
     for (int i = 0; i < learntWeights.length; i++) {
-      assertFalse("weights have not been changed at level " + i + 1, learntWeights[i].equals(initialWeights[i]));
+      assertFalse("weights have not been changed", learntWeights[i].equals(initialWeights[i]));
+    }
+  }
+
+  @Test
+  public void testLearningWithRandomNetworkAndRandomSettings() throws Exception {
+    BackPropagationLearningStrategy backPropagationLearningStrategy = new BackPropagationLearningStrategy(Math.random(),
+            Math.random(), new FeedForwardStrategy(Math.random() >= 0.5d ? new TanhFunction() : new SigmoidFunction()),
+            new LogisticRegressionCostFunction(Math.random()));
+
+    RealMatrix[] initialWeights = createRandomWeights();
+
+    Collection<TrainingExample<Double, Double>> samples = createSamples(1000, initialWeights[0].getColumnDimension() - 1);
+    TrainingSet<Double, Double> trainingSet = new TrainingSet<Double, Double>(samples);
+    RealMatrix[] learntWeights = backPropagationLearningStrategy.learnWeights(initialWeights, trainingSet);
+    assertNotNull(learntWeights);
+
+    for (int i = 0; i < learntWeights.length; i++) {
+      assertFalse("weights have not been changed", learntWeights[i].equals(initialWeights[i]));
     }
   }
 
@@ -113,7 +131,7 @@ public class BackPropagationLearningStra
     initialWeights[1] = new Array2DRowRealMatrix(new double[][]{{0d, 0d, 0d, 0d}, {1d, 0.5d, 1d, 0.5d}, {1d, 0.1d, 8d, 0.1d}, {1d, 0.1d, 8d, 0.2d}}); // 4 x 4
     initialWeights[2] = new Array2DRowRealMatrix(new double[][]{{1d, 2d, 0.3d, 0.5d}}); // 1 x 4
 
-    Collection<TrainingExample<Double, Double>> samples = createSamples(1000, 2);
+    Collection<TrainingExample<Double, Double>> samples = createSamples(10000, 2);
     TrainingSet<Double, Double> trainingSet = new TrainingSet<Double, Double>(samples);
     RealMatrix[] learntWeights = backPropagationLearningStrategy.learnWeights(initialWeights, trainingSet);
     assertNotNull(learntWeights);
@@ -135,7 +153,7 @@ public class BackPropagationLearningStra
     initialWeights[1] = new Array2DRowRealMatrix(new double[][]{{0d, 0d, 0d, 0d}, {1d, 0.5d, 1d, 0.5d}, {1d, 0.1d, 8d, 0.1d}, {1d, 0.1d, 8d, 0.2d}});
     initialWeights[2] = new Array2DRowRealMatrix(new double[][]{{1d, 2d, 0.3d, 0.5d}});
 
-    Collection<TrainingExample<Double, Double>> samples = createSamples(100, 2);
+    Collection<TrainingExample<Double, Double>> samples = createSamples(10000, 2);
     TrainingSet<Double, Double> trainingSet = new TrainingSet<Double, Double>(samples);
     RealMatrix[] learntWeights = backPropagationLearningStrategy.learnWeights(initialWeights, trainingSet);
     assertNotNull(learntWeights);
@@ -157,7 +175,7 @@ public class BackPropagationLearningStra
     initialWeights[1] = new Array2DRowRealMatrix(new double[][]{{0d, 0d, 0d, 0d}, {1d, Math.random(), Math.random(), Math.random()}, {1d, Math.random(), Math.random(), Math.random()}, {1d, Math.random(), Math.random(), Math.random()}});
     initialWeights[2] = new Array2DRowRealMatrix(new double[][]{{1d, Math.random(), Math.random(), Math.random()}});
 
-    Collection<TrainingExample<Double, Double>> samples = createSamples(500000, 2);
+    Collection<TrainingExample<Double, Double>> samples = createSamples(1000000, 2);
     TrainingSet<Double, Double> trainingSet = new TrainingSet<Double, Double>(samples);
     RealMatrix[] learntWeights = backPropagationLearningStrategy.learnWeights(initialWeights, trainingSet);
     assertNotNull(learntWeights);

Modified: labs/yay/trunk/core/src/test/java/org/apache/yay/core/BasicPerceptronTest.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/core/BasicPerceptronTest.java?rev=1497094&r1=1497093&r2=1497094&view=diff
==============================================================================
--- labs/yay/trunk/core/src/test/java/org/apache/yay/core/BasicPerceptronTest.java (original)
+++ labs/yay/trunk/core/src/test/java/org/apache/yay/core/BasicPerceptronTest.java Wed Jun 26 20:59:08 2013
@@ -25,7 +25,6 @@ import java.util.LinkedList;
 import org.apache.yay.Feature;
 import org.apache.yay.TrainingExample;
 import org.apache.yay.TrainingSet;
-import org.apache.yay.core.BasicPerceptron;
 import org.junit.Before;
 import org.junit.Test;
 
@@ -40,12 +39,12 @@ public class BasicPerceptronTest {
 
   @Before
   public void setUp() throws Exception {
-      Collection<TrainingExample<Double, Double>> samples = new LinkedList<TrainingExample<Double, Double>>() ;
-      samples.add(createTrainingExample(1d, 4d, 5d, 6d));
-      samples.add(createTrainingExample(1d, 5d, 6d, 0.5d));
-      samples.add(createTrainingExample(0.1d, 9d, 4d, 1.9d));
-      samples.add(createTrainingExample(0.11d, 4d, 2.6d, 9.5d));
-      dataset = new TrainingSet<Double,Double>(samples);
+    Collection<TrainingExample<Double, Double>> samples = new LinkedList<TrainingExample<Double, Double>>() ;
+    samples.add(createTrainingExample(1d, 4d, 5d, 6d));
+    samples.add(createTrainingExample(1d, 5d, 6d, 0.5d));
+    samples.add(createTrainingExample(0.1d, 9d, 4d, 1.9d));
+    samples.add(createTrainingExample(0.11d, 4d, 2.6d, 9.5d));
+    dataset = new TrainingSet<Double,Double>(samples);
   }
 
   @Test


