Author: tommaso
Date: Mon Sep 21 16:24:07 2015
New Revision: 1704348

URL: http://svn.apache.org/viewvc?rev=1704348&view=rev
Log:
switch from batch to stochastic GD in backprop
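Background on the change: batch gradient descent computes the gradient over the whole training set before every weight update, whereas stochastic (or mini-batch) gradient descent updates after each small sample drawn from it, trading noisier steps for much cheaper iterations. A minimal, self-contained sketch of the two update rules on a toy least-squares problem — GDSketch and everything in it are illustrative, not yay API:

    import java.util.Random;

    // Toy comparison (not yay code): fit y = w * x by gradient descent.
    public class GDSketch {
      public static void main(String[] args) {
        double[] xs = {1, 2, 3, 4};
        double[] ys = {2, 4, 6, 8}; // true w = 2
        double alpha = 0.05;

        // batch GD: one update per full pass over the data
        double w = 0;
        for (int epoch = 0; epoch < 100; epoch++) {
          double grad = 0;
          for (int i = 0; i < xs.length; i++) {
            grad += (w * xs[i] - ys[i]) * xs[i]; // d/dw of 0.5 * (w*x - y)^2
          }
          w -= alpha * grad / xs.length;
        }
        System.out.println("batch GD:      w = " + w);

        // stochastic GD: one update per randomly drawn example
        double wSgd = 0;
        Random rnd = new Random(42);
        for (int step = 0; step < 400; step++) {
          int i = rnd.nextInt(xs.length);
          wSgd -= alpha * (wSgd * xs[i] - ys[i]) * xs[i];
        }
        System.out.println("stochastic GD: w = " + wSgd);
      }
    }

Both runs converge to w = 2 here; the commit below makes the same trade-off configurable in BackPropagationLearningStrategy.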
Modified:
    labs/yay/trunk/api/src/main/java/org/apache/yay/CostFunction.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/core/LogisticRegressionCostFunction.java
    labs/yay/trunk/core/src/test/java/org/apache/yay/core/BackPropagationLearningStrategyTest.java
    labs/yay/trunk/pom.xml

Modified: labs/yay/trunk/api/src/main/java/org/apache/yay/CostFunction.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/api/src/main/java/org/apache/yay/CostFunction.java?rev=1704348&r1=1704347&r2=1704348&view=diff
==============================================================================
--- labs/yay/trunk/api/src/main/java/org/apache/yay/CostFunction.java (original)
+++ labs/yay/trunk/api/src/main/java/org/apache/yay/CostFunction.java Mon Sep 21 16:24:07 2015
@@ -39,14 +39,13 @@ public interface CostFunction<T, I, O> {
                                  Hypothesis<T, I, O> hypothesis) throws Exception;
 
   /**
-   * Calculate the cost of a single {@link org.apache.yay.TrainingExample} for a given {@link org.apache.yay.Hypothesis}
+   * Calculate the cost of one or more {@link org.apache.yay.TrainingExample}s for a given {@link org.apache.yay.Hypothesis}
    *
-   * @param trainingExample the training example
    * @param hypothesis       the hypothesis
+   * @param trainingExamples some training examples
    * @return a <code>Double</code> cost
    * @throws Exception if any error occurs during the cost calculation
    */
-  Double calculateCost(TrainingExample<I, O> trainingExample,
-                       Hypothesis<T, I, O> hypothesis) throws Exception;
+  Double calculateCost(Hypothesis<T, I, O> hypothesis, TrainingExample<I, O>... trainingExamples) throws Exception;
 
 }
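The varargs signature above folds the old single-example method into the batch case: one example is just a batch of length one. A hypothetical call site — costFunction, hypothesis, e1 and e2 stand in for real instances:

    Double oneExampleCost = costFunction.calculateCost(hypothesis, e1);      // stochastic: single sample
    Double miniBatchCost  = costFunction.calculateCost(hypothesis, e1, e2);  // mini-batch of two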
Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java?rev=1704348&r1=1704347&r2=1704348&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java Mon Sep 21 16:24:07 2015
@@ -19,6 +19,8 @@ package org.apache.yay.core;
 
 import java.util.Arrays;
+import java.util.Iterator;
+
 import org.apache.commons.math3.linear.Array2DRowRealMatrix;
 import org.apache.commons.math3.linear.ArrayRealVector;
 import org.apache.commons.math3.linear.RealMatrix;
@@ -46,13 +48,21 @@ public class BackPropagationLearningStra
   private final CostFunction<RealMatrix, Double, Double> costFunction;
   private final double alpha;
   private final double threshold;
+  private final int batch;
+
+  public BackPropagationLearningStrategy(double alpha, double threshold, PredictionStrategy<Double, Double> predictionStrategy,
+                                         CostFunction<RealMatrix, Double, Double> costFunction) {
+    this(alpha, 1, threshold, predictionStrategy, costFunction);
+  }
 
-  public BackPropagationLearningStrategy(double alpha, double threshold, PredictionStrategy<Double, Double> predictionStrategy, CostFunction<RealMatrix, Double, Double> costFunction) {
+  public BackPropagationLearningStrategy(double alpha, int batch, double threshold, PredictionStrategy<Double, Double> predictionStrategy,
+                                         CostFunction<RealMatrix, Double, Double> costFunction) {
     this.predictionStrategy = predictionStrategy;
     this.costFunction = costFunction;
     this.alpha = alpha;
     this.threshold = threshold;
+    this.batch = batch;
   }
 
   public BackPropagationLearningStrategy() {
@@ -61,6 +71,7 @@ public class BackPropagationLearningStra
     this.costFunction = new LogisticRegressionCostFunction();
     this.alpha = DEFAULT_ALPHA;
     this.threshold = DEFAULT_THRESHOLD;
+    this.batch = 1;
   }
 
   @Override
@@ -70,13 +81,29 @@ public class BackPropagationLearningStra
     int iterations = 0;
 
     NeuralNetwork hypothesis = NeuralNetworkFactory.create(weightsMatrixSet, new VoidLearningStrategy<Double, Double>(),
            predictionStrategy, new MaxSelectionFunction<Double>());
+    Iterator<TrainingExample<Double, Double>> iterator = trainingExamples.iterator();
 
     double cost = Double.MAX_VALUE;
     while (true) {
+
+      TrainingSet<Double, Double> samples;
+      if (batch == -1) {
+        samples = trainingExamples;
+      } else {
+        TrainingExample<Double, Double>[] miniBatch = new TrainingExample[batch];
+        for (int i = 0; i < batch; i++) {
+          if (!iterator.hasNext()) {
+            iterator = trainingExamples.iterator();
+          }
+          miniBatch[i] = iterator.next();
+        }
+        samples = new TrainingSet<>(Arrays.asList(miniBatch));
+      }
+
       // calculate cost
-      double newCost = costFunction.calculateAggregatedCost(trainingExamples, hypothesis);
+      double newCost = costFunction.calculateAggregatedCost(samples, hypothesis);
 
-      if (newCost > cost) {
+      if (newCost > cost && batch == -1) {
         throw new RuntimeException("failed to converge at iteration " + iterations + " with alpha " + alpha + " : cost going from " + cost + " to " + newCost);
       } else if (cost == newCost || newCost < threshold || iterations > MAX_ITERATIONS) {
         System.out.println("successfully converged after " + iterations + " iterations (alpha:" + alpha + ",threshold:" + threshold + ") with cost " + newCost + " and parameters " + Arrays.toString(hypothesis.getParameters()));
@@ -87,7 +114,7 @@ public class BackPropagationLearningStra
       cost = newCost;
 
       // calculate the derivatives to update the parameters
-      RealMatrix[] derivatives = calculateDerivatives(weightsMatrixSet, trainingExamples);
+      RealMatrix[] derivatives = calculateDerivatives(weightsMatrixSet, samples);
 
       // calculate the updated parameters
       updatedWeights = updateWeights(weightsMatrixSet, derivatives, alpha);
@@ -132,9 +159,7 @@ public class BackPropagationLearningStra
         RealVector[] newActivations = new RealVector[activations.length];
         newActivations[0] = ConversionUtils.toRealVector(ConversionUtils.toValuesCollection(trainingExample.getFeatures()));
-        for (int k = 0; k < activations.length - 1; k++) {
-          newActivations[k + 1] = activations[k];
-        }
+        System.arraycopy(activations, 0, newActivations, 1, activations.length - 1);
 
         // update triangle (big delta matrix)
         updateTriangle(triangle, newActivations, deltaVectors, weightsMatrixSet);
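Note on the learnWeights() hunk above: the new batch field controls how many examples are drawn per iteration. batch == -1 keeps the old full-batch behaviour (and keeps the cost-increase divergence check enabled), batch == 1 — the default, and what the old four-argument constructor now delegates to — gives plain stochastic GD, and any larger value gives mini-batches; the iterator simply wraps around when the training set is exhausted. A hypothetical call site (the alpha/threshold values and the activation function choice are arbitrary):

    // pure stochastic GD: one example per weight update (same as the 4-arg constructor)
    BackPropagationLearningStrategy sgd =
        new BackPropagationLearningStrategy(0.01d, 1, 0.05d,
            new FeedForwardStrategy(new SigmoidFunction()), new LogisticRegressionCostFunction());

    // mini-batch GD: ten examples per weight update
    BackPropagationLearningStrategy miniBatch =
        new BackPropagationLearningStrategy(0.01d, 10, 0.05d,
            new FeedForwardStrategy(new SigmoidFunction()), new LogisticRegressionCostFunction());

    // old behaviour: full-batch GD over the whole training set
    BackPropagationLearningStrategy fullBatch =
        new BackPropagationLearningStrategy(0.01d, -1, 0.05d,
            new FeedForwardStrategy(new SigmoidFunction()), new LogisticRegressionCostFunction());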
Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/LogisticRegressionCostFunction.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/LogisticRegressionCostFunction.java?rev=1704348&r1=1704347&r2=1704348&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/LogisticRegressionCostFunction.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/LogisticRegressionCostFunction.java Mon Sep 21 16:24:07 2015
@@ -43,16 +43,19 @@ public class LogisticRegressionCostFunct
   }
 
   @Override
-  public Double calculateAggregatedCost(TrainingSet<Double, Double> trainingExamples,
+  public Double calculateAggregatedCost(TrainingSet<Double, Double> trainingSet,
                                         Hypothesis<RealMatrix, Double, Double> hypothesis) throws Exception {
-
-    Double errorTerm = calculateErrorTerm(hypothesis, trainingExamples);
-    Double regularizationTerm = calculateRegularizationTerm(hypothesis, trainingExamples);
-    return errorTerm + regularizationTerm;
+    TrainingExample<Double, Double>[] samples = new TrainingExample[trainingSet.size()];
+    int i = 0;
+    for (TrainingExample<Double, Double> sample : trainingSet) {
+      samples[i] = sample;
+      i++;
+    }
+    return calculateCost(hypothesis, samples);
   }
 
   private Double calculateRegularizationTerm(Hypothesis<RealMatrix, Double, Double> hypothesis,
-                                             TrainingSet<Double, Double> trainingExamples) {
+                                             TrainingExample<Double, Double>... trainingExamples) {
     Double res = 1d;
     for (RealMatrix layerMatrix : hypothesis.getParameters()) {
       for (int i = 0; i < layerMatrix.getColumnDimension(); i++) {
@@ -63,12 +66,11 @@ public class LogisticRegressionCostFunct
         }
       }
     }
-    return (lambda / (2d * trainingExamples.size())) * res;
+    return (lambda / (2d * trainingExamples.length)) * res;
  }
 
   private Double calculateErrorTerm(Hypothesis<RealMatrix, Double, Double> hypothesis,
-                                    TrainingSet<Double, Double> trainingExamples) throws PredictionException,
-      CreationException {
+                                    TrainingExample<Double, Double>... trainingExamples) throws PredictionException, CreationException {
     Double res = 0d;
 
     for (TrainingExample<Double, Double> input : trainingExamples) {
@@ -78,11 +80,13 @@ public class LogisticRegressionCostFunct
       res += sampleOutput * Math.log(predictedOutput) + (1d - sampleOutput) * Math.log(1d - predictedOutput);
     }
 
-    return (-1d / trainingExamples.size()) * res;
+    return (-1d / trainingExamples.length) * res;
   }
 
   @Override
-  public Double calculateCost(TrainingExample<Double, Double> trainingExample, Hypothesis<RealMatrix, Double, Double> hypothesis) throws Exception {
-    return null;
+  public Double calculateCost(Hypothesis<RealMatrix, Double, Double> hypothesis, TrainingExample<Double, Double>... trainingExamples) throws Exception {
+    Double errorTerm = calculateErrorTerm(hypothesis, trainingExamples);
+    Double regularizationTerm = calculateRegularizationTerm(hypothesis, trainingExamples);
+    return errorTerm + regularizationTerm;
   }
 }
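For reference, the calculateCost(...) now implemented above is the usual regularized logistic (cross-entropy) cost, with m = trainingExamples.length, i.e. the size of the sampled mini-batch rather than of the whole training set:

    J(W) = -(1/m) * Σ_i [ y_i * log(h_W(x_i)) + (1 - y_i) * log(1 - h_W(x_i)) ] + (λ / (2m)) * Σ (weights²)

calculateErrorTerm computes the first sum and calculateRegularizationTerm the second (modulo the accumulation details elided by the hunk); normalizing both terms by the batch length is what lets the same code price a single stochastic sample or the full set.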
Modified: labs/yay/trunk/core/src/test/java/org/apache/yay/core/BackPropagationLearningStrategyTest.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/core/BackPropagationLearningStrategyTest.java?rev=1704348&r1=1704347&r2=1704348&view=diff
==============================================================================
--- labs/yay/trunk/core/src/test/java/org/apache/yay/core/BackPropagationLearningStrategyTest.java (original)
+++ labs/yay/trunk/core/src/test/java/org/apache/yay/core/BackPropagationLearningStrategyTest.java Mon Sep 21 16:24:07 2015
@@ -55,9 +55,12 @@ public class BackPropagationLearningStra
 
   @Test
   public void testLearningWithRandomNetworkAndRandomSettings() throws Exception {
-    BackPropagationLearningStrategy backPropagationLearningStrategy = new BackPropagationLearningStrategy(Math.random(),
-        Math.random(), new FeedForwardStrategy(Math.random() >= 0.5d ? new TanhFunction() : new SigmoidFunction()),
-        new LogisticRegressionCostFunction(Math.random()));
+    double alpha = Math.random();
+    double threshold = Math.random() * 0.001;
+    FeedForwardStrategy predictionStrategy = new FeedForwardStrategy(Math.random() >= 0.5d ? new TanhFunction() : new SigmoidFunction());
+    LogisticRegressionCostFunction costFunction = new LogisticRegressionCostFunction(Math.random());
+    BackPropagationLearningStrategy backPropagationLearningStrategy = new BackPropagationLearningStrategy(alpha, threshold,
+        predictionStrategy, costFunction);
 
     RealMatrix[] initialWeights = createRandomWeights();
 
@@ -67,6 +70,14 @@ public class BackPropagationLearningStra
     assertNotNull(learntWeights);
 
     for (int i = 0; i < learntWeights.length; i++) {
+      assertFalse("weights have not been changed", learntWeights[i].equals(initialWeights[i]));
+    }
+
+    backPropagationLearningStrategy = new BackPropagationLearningStrategy(alpha, 1, threshold, predictionStrategy, costFunction);
+    learntWeights = backPropagationLearningStrategy.learnWeights(initialWeights, trainingSet);
+    assertNotNull(learntWeights);
+
+    for (int i = 0; i < learntWeights.length; i++) {
       assertFalse("weights have not been changed", learntWeights[i].equals(initialWeights[i]));
     }
   }

Modified: labs/yay/trunk/pom.xml
URL: http://svn.apache.org/viewvc/labs/yay/trunk/pom.xml?rev=1704348&r1=1704347&r2=1704348&view=diff
==============================================================================
--- labs/yay/trunk/pom.xml (original)
+++ labs/yay/trunk/pom.xml Mon Sep 21 16:24:07 2015
@@ -152,9 +152,8 @@
         <artifactId>maven-compiler-plugin</artifactId>
         <version>2.0.2</version>
         <configuration>
-          <compilerVersion>1.6</compilerVersion>
-          <source>1.6</source>
-          <target>1.6</target>
+          <source>1.8</source>
+          <target>1.8</target>
           <encoding>UTF-8</encoding>
         </configuration>
       </plugin>