Author: tommaso Date: Mon Oct 5 12:02:32 2015 New Revision: 1706816 URL: http://svn.apache.org/viewvc?rev=1706816&view=rev Log: added LMS, fixed backprop bias checks
Added: labs/yay/trunk/core/src/main/java/org/apache/yay/core/LMSCostFunction.java Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java labs/yay/trunk/core/src/test/java/org/apache/yay/core/BackPropagationLearningStrategyTest.java labs/yay/trunk/core/src/test/java/org/apache/yay/core/NeuralNetworkIntegrationTest.java Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java?rev=1706816&r1=1706815&r2=1706816&view=diff ============================================================================== --- labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java (original) +++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java Mon Oct 5 12:02:32 2015 @@ -38,9 +38,9 @@ import org.apache.yay.WeightLearningExce */ public class BackPropagationLearningStrategy implements LearningStrategy<Double, Double> { - public static final double DEFAULT_THRESHOLD = 0.005; + public static final double DEFAULT_THRESHOLD = 0.05; public static final int MAX_ITERATIONS = 100000; - public static final double DEFAULT_ALPHA = 0.03; + public static final double DEFAULT_ALPHA = 0.000003; private final PredictionStrategy<Double, Double> predictionStrategy; private final CostFunction<RealMatrix, Double, Double> costFunction; @@ -139,7 +139,7 @@ public class BackPropagationLearningStra for (int i = 0; i < updatedWeights.length; i++) { for (int j = 0; j < updatedWeights[i].length; j++) { double curVal = updatedWeights[i][j]; - if (curVal > 0d && curVal < 1d) { + if (!(i == 0 && curVal == 0d) && !(j == 0 && curVal == 1d)) { updatedWeights[i][j] = updatedWeights[i][j] - alpha * derivatives[l].getData()[i][j]; } } Added: labs/yay/trunk/core/src/main/java/org/apache/yay/core/LMSCostFunction.java URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/LMSCostFunction.java?rev=1706816&view=auto ============================================================================== --- labs/yay/trunk/core/src/main/java/org/apache/yay/core/LMSCostFunction.java (added) +++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/LMSCostFunction.java Mon Oct 5 12:02:32 2015 @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.yay.core; + +import org.apache.commons.math3.linear.RealMatrix; +import org.apache.yay.Hypothesis; +import org.apache.yay.NeuralNetworkCostFunction; +import org.apache.yay.TrainingExample; +import org.apache.yay.TrainingSet; + +/** + * Least mean square cost function + */ +public class LMSCostFunction implements NeuralNetworkCostFunction { + @Override + public Double calculateAggregatedCost(TrainingSet<Double, Double> trainingExamples, Hypothesis<RealMatrix, Double, Double> hypothesis) throws Exception { + TrainingExample<Double, Double>[] samples = new TrainingExample[trainingExamples.size()]; + int i = 0; + for (TrainingExample<Double, Double> sample : trainingExamples) { + samples[i] = sample; + i++; + } + return calculateCost(hypothesis, samples); + } + + @Override + public Double calculateCost(Hypothesis<RealMatrix, Double, Double> hypothesis, TrainingExample<Double, Double>... trainingExamples) throws Exception { + Double cost = 0d; + for (TrainingExample<Double, Double> example : trainingExamples) { + Double[] actualOutput = example.getOutput(); + Double[] predictedOutput = hypothesis.predict(example); + for (int i = 0; i < actualOutput.length; i++) { + cost += actualOutput[i] - predictedOutput[i]; + } + } + return Math.pow(cost, 2) / 2; + } +} Modified: labs/yay/trunk/core/src/test/java/org/apache/yay/core/BackPropagationLearningStrategyTest.java URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/core/BackPropagationLearningStrategyTest.java?rev=1706816&r1=1706815&r2=1706816&view=diff ============================================================================== --- labs/yay/trunk/core/src/test/java/org/apache/yay/core/BackPropagationLearningStrategyTest.java (original) +++ labs/yay/trunk/core/src/test/java/org/apache/yay/core/BackPropagationLearningStrategyTest.java Mon Oct 5 12:02:32 2015 @@ -137,7 +137,7 @@ public class BackPropagationLearningStra initialWeights[1] = new Array2DRowRealMatrix(new double[][]{{0d, 0d, 0d, 0d}, {1d, 0.5d, 1d, 0.5d}, {1d, 0.1d, 8d, 0.1d}, {1d, 0.1d, 8d, 0.2d}}); // 4 x 4 initialWeights[2] = new Array2DRowRealMatrix(new double[][]{{1d, 2d, 0.3d, 0.5d}}); // 1 x 4 - Collection<TrainingExample<Double, Double>> samples = createSamples(10000, 2, 1); + Collection<TrainingExample<Double, Double>> samples = createSamples(100, 2, 1); TrainingSet<Double, Double> trainingSet = new TrainingSet<Double, Double>(samples); RealMatrix[] learntWeights = backPropagationLearningStrategy.learnWeights(initialWeights, trainingSet); assertNotNull(learntWeights); Modified: labs/yay/trunk/core/src/test/java/org/apache/yay/core/NeuralNetworkIntegrationTest.java URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/core/NeuralNetworkIntegrationTest.java?rev=1706816&r1=1706815&r2=1706816&view=diff ============================================================================== --- labs/yay/trunk/core/src/test/java/org/apache/yay/core/NeuralNetworkIntegrationTest.java (original) +++ labs/yay/trunk/core/src/test/java/org/apache/yay/core/NeuralNetworkIntegrationTest.java Mon Oct 5 12:02:32 2015 @@ -186,7 +186,7 @@ public class NeuralNetworkIntegrationTes if (c == 0) { d[0][c] = 1d; } else { - d[0][c] = r.nextInt(100) / 101d;; + d[0][c] = r.nextInt(100) / 101d; } } else { d[0][c] = 0; @@ -199,7 +199,7 @@ public class NeuralNetworkIntegrationTes if (j == 0) { val = 1d; } else { - val = r.nextInt(100) / 101d;; + val = r.nextInt(100) / 101d; } d[k][j] = val; } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@labs.apache.org For additional commands, e-mail: commits-h...@labs.apache.org