Author: tommaso
Date: Mon Sep 21 16:24:07 2015
New Revision: 1704348

URL: http://svn.apache.org/viewvc?rev=1704348&view=rev
Log:
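Switch backpropagation from batch to stochastic gradient descent (the mini-batch size is now configurable and defaults to 1; a size of -1 uses the full training set).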

Modified:
    labs/yay/trunk/api/src/main/java/org/apache/yay/CostFunction.java
    
labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java
    
labs/yay/trunk/core/src/main/java/org/apache/yay/core/LogisticRegressionCostFunction.java
    
labs/yay/trunk/core/src/test/java/org/apache/yay/core/BackPropagationLearningStrategyTest.java
    labs/yay/trunk/pom.xml

Modified: labs/yay/trunk/api/src/main/java/org/apache/yay/CostFunction.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/api/src/main/java/org/apache/yay/CostFunction.java?rev=1704348&r1=1704347&r2=1704348&view=diff
==============================================================================
--- labs/yay/trunk/api/src/main/java/org/apache/yay/CostFunction.java (original)
+++ labs/yay/trunk/api/src/main/java/org/apache/yay/CostFunction.java Mon Sep 21 16:24:07 2015
@@ -39,14 +39,13 @@ public interface CostFunction<T, I, O> {
                                  Hypothesis<T, I, O> hypothesis) throws 
Exception;
 
   /**
-   * Calculate the cost of a single {@link org.apache.yay.TrainingExample} for 
a given {@link org.apache.yay.Hypothesis}
+   * Calculate the cost of one or more {@link org.apache.yay.TrainingExample}s 
for a given {@link org.apache.yay.Hypothesis}
    *
-   * @param trainingExample the training example
    * @param hypothesis      the hypothesis
+   * @param trainingExamples some training examples
    * @return a <code>Double</code> cost
    * @throws Exception if any error occurs during the cost calculation
    */
-  Double calculateCost(TrainingExample<I, O> trainingExample,
-                       Hypothesis<T, I, O> hypothesis) throws Exception;
+  Double calculateCost(Hypothesis<T, I, O> hypothesis, TrainingExample<I, 
O>... trainingExamples) throws Exception;
 
 }

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java?rev=1704348&r1=1704347&r2=1704348&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java Mon Sep 21 16:24:07 2015
@@ -19,6 +19,8 @@
 package org.apache.yay.core;
 
 import java.util.Arrays;
+import java.util.Iterator;
+
 import org.apache.commons.math3.linear.Array2DRowRealMatrix;
 import org.apache.commons.math3.linear.ArrayRealVector;
 import org.apache.commons.math3.linear.RealMatrix;
@@ -46,13 +48,21 @@ public class BackPropagationLearningStra
   private final CostFunction<RealMatrix, Double, Double> costFunction;
   private final double alpha;
   private final double threshold;
+  private final int batch;
+
 
+  public BackPropagationLearningStrategy(double alpha, double threshold, 
PredictionStrategy<Double, Double> predictionStrategy,
+                                         CostFunction<RealMatrix, Double, 
Double> costFunction) {
+    this(alpha, 1, threshold, predictionStrategy, costFunction);
+  }
 
-  public BackPropagationLearningStrategy(double alpha, double threshold, 
PredictionStrategy<Double, Double> predictionStrategy, CostFunction<RealMatrix, 
Double, Double> costFunction) {
+  public BackPropagationLearningStrategy(double alpha, int batch, double 
threshold, PredictionStrategy<Double, Double> predictionStrategy,
+                                         CostFunction<RealMatrix, Double, 
Double> costFunction) {
     this.predictionStrategy = predictionStrategy;
     this.costFunction = costFunction;
     this.alpha = alpha;
     this.threshold = threshold;
+    this.batch = batch;
   }
 
   public BackPropagationLearningStrategy() {
@@ -61,6 +71,7 @@ public class BackPropagationLearningStra
     this.costFunction = new LogisticRegressionCostFunction();
     this.alpha = DEFAULT_ALPHA;
     this.threshold = DEFAULT_THRESHOLD;
+    this.batch = 1;
   }
 
   @Override
@@ -70,13 +81,29 @@ public class BackPropagationLearningStra
       int iterations = 0;
 
       NeuralNetwork hypothesis = NeuralNetworkFactory.create(weightsMatrixSet, 
new VoidLearningStrategy<Double, Double>(), predictionStrategy, new 
MaxSelectionFunction<Double>());
+      Iterator<TrainingExample<Double, Double>> iterator = 
trainingExamples.iterator();
 
       double cost = Double.MAX_VALUE;
       while (true) {
+
+        TrainingSet<Double, Double> samples;
+        if (batch == -1) {
+          samples = trainingExamples;
+        } else {
+          TrainingExample<Double, Double>[] miniBatch = new 
TrainingExample[batch];
+          for (int i = 0; i < batch; i++) {
+            if (!iterator.hasNext()) {
+              iterator = trainingExamples.iterator();
+            }
+            miniBatch[i] = iterator.next();
+          }
+          samples = new TrainingSet<>(Arrays.asList(miniBatch));
+        }
+
         // calculate cost
-        double newCost = 
costFunction.calculateAggregatedCost(trainingExamples, hypothesis);
+        double newCost = costFunction.calculateAggregatedCost(samples, 
hypothesis);
 
-        if (newCost > cost) {
+        if (newCost > cost && batch == -1) {
           throw new RuntimeException("failed to converge at iteration " + 
iterations + " with alpha " + alpha + " : cost going from " + cost + " to " + 
newCost);
         } else if (cost == newCost || newCost < threshold || iterations > 
MAX_ITERATIONS) {
           System.out.println("successfully converged after " + iterations + " 
iterations (alpha:" + alpha + ",threshold:" + threshold + ") with cost " + 
newCost + " and parameters " + Arrays.toString(hypothesis.getParameters()));
@@ -87,7 +114,7 @@ public class BackPropagationLearningStra
         cost = newCost;
 
         // calculate the derivatives to update the parameters
-        RealMatrix[] derivatives = calculateDerivatives(weightsMatrixSet, 
trainingExamples);
+        RealMatrix[] derivatives = calculateDerivatives(weightsMatrixSet, 
samples);
 
         // calculate the updated parameters
         updatedWeights = updateWeights(weightsMatrixSet, derivatives, alpha);
@@ -132,9 +159,7 @@ public class BackPropagationLearningStra
 
         RealVector[] newActivations = new RealVector[activations.length];
         newActivations[0] = 
ConversionUtils.toRealVector(ConversionUtils.toValuesCollection(trainingExample.getFeatures()));
-        for (int k = 0; k < activations.length - 1; k++) {
-          newActivations[k + 1] = activations[k];
-        }
+        System.arraycopy(activations, 0, newActivations, 1, activations.length 
- 1);
 
         // update triangle (big delta matrix)
         updateTriangle(triangle, newActivations, deltaVectors, 
weightsMatrixSet);

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/LogisticRegressionCostFunction.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/LogisticRegressionCostFunction.java?rev=1704348&r1=1704347&r2=1704348&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/LogisticRegressionCostFunction.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/LogisticRegressionCostFunction.java Mon Sep 21 16:24:07 2015
@@ -43,16 +43,19 @@ public class LogisticRegressionCostFunct
   }
 
   @Override
-  public Double calculateAggregatedCost(TrainingSet<Double, Double> 
trainingExamples,
+  public Double calculateAggregatedCost(TrainingSet<Double, Double> 
trainingSet,
                                         Hypothesis<RealMatrix, Double, Double> 
hypothesis) throws Exception {
-
-    Double errorTerm = calculateErrorTerm(hypothesis, trainingExamples);
-    Double regularizationTerm = calculateRegularizationTerm(hypothesis, 
trainingExamples);
-    return errorTerm + regularizationTerm;
+    TrainingExample<Double, Double>[] samples = new 
TrainingExample[trainingSet.size()];
+    int i = 0;
+    for (TrainingExample<Double, Double> sample : trainingSet) {
+      samples[i] = sample;
+      i++;
+    }
+    return calculateCost(hypothesis, samples);
   }
 
   private Double calculateRegularizationTerm(Hypothesis<RealMatrix, Double, 
Double> hypothesis,
-                                             TrainingSet<Double, Double> 
trainingExamples) {
+                                             TrainingExample<Double, 
Double>... trainingExamples) {
     Double res = 1d;
     for (RealMatrix layerMatrix : hypothesis.getParameters()) {
       for (int i = 0; i < layerMatrix.getColumnDimension(); i++) {
@@ -63,12 +66,11 @@ public class LogisticRegressionCostFunct
         }
       }
     }
-    return (lambda / (2d * trainingExamples.size())) * res;
+    return (lambda / (2d * trainingExamples.length)) * res;
   }
 
   private Double calculateErrorTerm(Hypothesis<RealMatrix, Double, Double> 
hypothesis,
-                                    TrainingSet<Double, Double> 
trainingExamples) throws PredictionException,
-          CreationException {
+                                    TrainingExample<Double, Double>... 
trainingExamples) throws PredictionException, CreationException {
     Double res = 0d;
 
     for (TrainingExample<Double, Double> input : trainingExamples) {
@@ -78,11 +80,13 @@ public class LogisticRegressionCostFunct
       res += sampleOutput * Math.log(predictedOutput) + (1d - sampleOutput)
               * Math.log(1d - predictedOutput);
     }
-    return (-1d / trainingExamples.size()) * res;
+    return (-1d / trainingExamples.length) * res;
   }
 
   @Override
-  public Double calculateCost(TrainingExample<Double, Double> trainingExample, 
Hypothesis<RealMatrix, Double, Double> hypothesis) throws Exception {
-    return null;
+  public Double calculateCost(Hypothesis<RealMatrix, Double, Double> 
hypothesis, TrainingExample<Double, Double>... trainingExamples) throws 
Exception {
+    Double errorTerm = calculateErrorTerm(hypothesis, trainingExamples);
+    Double regularizationTerm = calculateRegularizationTerm(hypothesis, 
trainingExamples);
+    return errorTerm + regularizationTerm;
   }
 }

Modified: labs/yay/trunk/core/src/test/java/org/apache/yay/core/BackPropagationLearningStrategyTest.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/core/BackPropagationLearningStrategyTest.java?rev=1704348&r1=1704347&r2=1704348&view=diff
==============================================================================
--- labs/yay/trunk/core/src/test/java/org/apache/yay/core/BackPropagationLearningStrategyTest.java (original)
+++ labs/yay/trunk/core/src/test/java/org/apache/yay/core/BackPropagationLearningStrategyTest.java Mon Sep 21 16:24:07 2015
@@ -55,9 +55,12 @@ public class BackPropagationLearningStra
 
   @Test
   public void testLearningWithRandomNetworkAndRandomSettings() throws 
Exception {
-    BackPropagationLearningStrategy backPropagationLearningStrategy = new 
BackPropagationLearningStrategy(Math.random(),
-            Math.random(), new FeedForwardStrategy(Math.random() >= 0.5d ? new 
TanhFunction() : new SigmoidFunction()),
-            new LogisticRegressionCostFunction(Math.random()));
+    double alpha = Math.random();
+    double threshold = Math.random() * 0.001;
+    FeedForwardStrategy predictionStrategy = new 
FeedForwardStrategy(Math.random() >= 0.5d ? new TanhFunction() : new 
SigmoidFunction());
+    LogisticRegressionCostFunction costFunction = new 
LogisticRegressionCostFunction(Math.random());
+    BackPropagationLearningStrategy backPropagationLearningStrategy = new 
BackPropagationLearningStrategy(alpha, threshold,
+            predictionStrategy, costFunction);
 
     RealMatrix[] initialWeights = createRandomWeights();
 
@@ -67,6 +70,14 @@ public class BackPropagationLearningStra
     assertNotNull(learntWeights);
 
     for (int i = 0; i < learntWeights.length; i++) {
+      assertFalse("weights have not been changed", 
learntWeights[i].equals(initialWeights[i]));
+    }
+
+    backPropagationLearningStrategy = new 
BackPropagationLearningStrategy(alpha, 1, threshold, predictionStrategy, 
costFunction);
+    learntWeights = 
backPropagationLearningStrategy.learnWeights(initialWeights, trainingSet);
+    assertNotNull(learntWeights);
+
+    for (int i = 0; i < learntWeights.length; i++) {
       assertFalse("weights have not been changed", 
learntWeights[i].equals(initialWeights[i]));
     }
   }

Modified: labs/yay/trunk/pom.xml
URL: http://svn.apache.org/viewvc/labs/yay/trunk/pom.xml?rev=1704348&r1=1704347&r2=1704348&view=diff
==============================================================================
--- labs/yay/trunk/pom.xml (original)
+++ labs/yay/trunk/pom.xml Mon Sep 21 16:24:07 2015
@@ -152,9 +152,8 @@
         <artifactId>maven-compiler-plugin</artifactId>
         <version>2.0.2</version>
         <configuration>
-          <compilerVersion>1.6</compilerVersion>
-          <source>1.6</source>
-          <target>1.6</target>
+          <source>1.8</source>
+          <target>1.8</target>
           <encoding>UTF-8</encoding>
         </configuration>
       </plugin>



---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@labs.apache.org
For additional commands, e-mail: commits-h...@labs.apache.org

Reply via email to