Author: tommaso Date: Wed Oct 28 11:43:25 2015 New Revision: 1710995 URL: http://svn.apache.org/viewvc?rev=1710995&view=rev Log: slightly parallelize weight matrix update
Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java labs/yay/trunk/core/src/test/java/org/apache/yay/core/BackPropagationLearningStrategyTest.java Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java?rev=1710995&r1=1710994&r2=1710995&view=diff ============================================================================== --- labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java (original) +++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java Wed Oct 28 11:43:25 2015 @@ -18,19 +18,15 @@ */ package org.apache.yay.core; -import java.util.Arrays; -import java.util.Iterator; - import org.apache.commons.math3.linear.Array2DRowRealMatrix; import org.apache.commons.math3.linear.RealMatrix; -import org.apache.yay.CostFunction; -import org.apache.yay.DerivativeUpdateFunction; -import org.apache.yay.LearningStrategy; -import org.apache.yay.NeuralNetwork; -import org.apache.yay.PredictionStrategy; -import org.apache.yay.TrainingExample; -import org.apache.yay.TrainingSet; -import org.apache.yay.WeightLearningException; +import org.apache.yay.*; + +import java.util.Arrays; +import java.util.Collection; +import java.util.Iterator; +import java.util.LinkedList; +import java.util.concurrent.*; /** * Back propagation learning algorithm for neural networks implementation (see @@ -50,6 +46,8 @@ public class BackPropagationLearningStra private final int batch; private final int maxIterations; + private final ExecutorService executorService = Executors.newCachedThreadPool(); + public BackPropagationLearningStrategy(double alpha, double threshold, PredictionStrategy<Double, Double> predictionStrategy, CostFunction<RealMatrix, Double, Double> costFunction) { this(alpha, 1, threshold, predictionStrategy, costFunction, MAX_ITERATIONS); @@ -110,7 +108,7 @@ public class BackPropagationLearningStra } else if (iterations > 1 && (cost == newCost || newCost < threshold || iterations > maxIterations)) { System.out.println("successfully converged after " + (iterations - 1) + " iterations (alpha:" + alpha + ",threshold:" + threshold + ") with cost " + newCost + " and parameters " + Arrays.toString(hypothesis.getParameters())); break; - } else if (Double.isNaN(newCost)){ + } else if (Double.isNaN(newCost)) { throw new RuntimeException("failed to converge at iteration " + iterations + " with alpha " + alpha + " : cost calculation underflow"); } @@ -139,14 +137,7 @@ public class BackPropagationLearningStra RealMatrix[] updatedParameters = new RealMatrix[weightsMatrixSet.length]; for (int l = 0; l < weightsMatrixSet.length; l++) { double[][] updatedWeights = weightsMatrixSet[l].getData(); - for (int i = 0; i < updatedWeights.length; i++) { - for (int j = 0; j < updatedWeights[i].length; j++) { - double curVal = updatedWeights[i][j]; - if (!(i == 0 && curVal == 0d) && !(j == 0 && curVal == 1d)) { - updatedWeights[i][j] = updatedWeights[i][j] - alpha * derivatives[l].getData()[i][j]; - } - } - } + updateMatrix(derivatives, alpha, l, updatedWeights); if (updatedParameters[l] != null) { updatedParameters[l].setSubMatrix(updatedWeights, 0, 0); } else { @@ -156,4 +147,36 @@ public class BackPropagationLearningStra return updatedParameters; } + private void updateMatrix(final RealMatrix[] derivatives, final double alpha, final int l, final double[][] updatedWeights) { + Collection<Future<Double>> futures = new LinkedList<Future<Double>>(); + for (int i = 0; i < updatedWeights.length; i++) { + for (int j = 0; j < updatedWeights[i].length; j++) { + final int finalI = i; + final int finalJ = j; + Callable<Double> callable = new Callable<Double>() { + @Override + public Double call() throws Exception { + double curVal = updatedWeights[finalI][finalJ]; + double val; + if (!(finalI == 0 && curVal == 0d) && !(finalJ == 0 && curVal == 1d)) { + val = -alpha * derivatives[l].getData()[finalI][finalJ]; + updatedWeights[finalI][finalJ] = val; + } else { + val = curVal; + } + return val; + } + }; + futures.add(executorService.submit(callable)); + } + } + for (Future<Double> f : futures) { + try { + f.get(3, TimeUnit.SECONDS); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + } + } Modified: labs/yay/trunk/core/src/test/java/org/apache/yay/core/BackPropagationLearningStrategyTest.java URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/core/BackPropagationLearningStrategyTest.java?rev=1710995&r1=1710994&r2=1710995&view=diff ============================================================================== --- labs/yay/trunk/core/src/test/java/org/apache/yay/core/BackPropagationLearningStrategyTest.java (original) +++ labs/yay/trunk/core/src/test/java/org/apache/yay/core/BackPropagationLearningStrategyTest.java Wed Oct 28 11:43:25 2015 @@ -143,9 +143,9 @@ public class BackPropagationLearningStra RealMatrix[] learntWeights = backPropagationLearningStrategy.learnWeights(initialWeights, trainingSet); assertNotNull(learntWeights); - assertFalse(learntWeights[0].equals(initialWeights[0])); - assertFalse(learntWeights[1].equals(initialWeights[1])); - assertFalse(learntWeights[2].equals(initialWeights[2])); + for (int i = 0; i < learntWeights.length; i++) { + assertFalse("weights have not been changed", learntWeights[i].equals(initialWeights[i])); + } backPropagationLearningStrategy = new BackPropagationLearningStrategy(BackPropagationLearningStrategy.DEFAULT_ALPHA, -1, BackPropagationLearningStrategy.DEFAULT_THRESHOLD, new FeedForwardStrategy(new SigmoidFunction()), @@ -154,9 +154,9 @@ public class BackPropagationLearningStra learntWeights = backPropagationLearningStrategy.learnWeights(initialWeights, trainingSet); assertNotNull(learntWeights); - assertFalse(learntWeights[0].equals(initialWeights[0])); - assertFalse(learntWeights[1].equals(initialWeights[1])); - assertFalse(learntWeights[2].equals(initialWeights[2])); + for (int i = 0; i < learntWeights.length; i++) { + assertFalse("weights have not been changed", learntWeights[i].equals(initialWeights[i])); + } } @Test --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@labs.apache.org For additional commands, e-mail: commits-h...@labs.apache.org