Author: tommaso
Date: Thu Jan 14 16:19:32 2016
New Revision: 1724646

URL: http://svn.apache.org/viewvc?rev=1724646&view=rev
Log: pruned old api, minimal shallow ff nn impl
Added: labs/yay/trunk/core/src/main/java/org/apache/yay/ActivationFunction.java (with props) labs/yay/trunk/core/src/main/java/org/apache/yay/IdentityActivationFunction.java (with props) labs/yay/trunk/core/src/main/java/org/apache/yay/ShallowFeedForwardNeuralNetwork.java (with props) labs/yay/trunk/core/src/main/java/org/apache/yay/SigmoidFunction.java - copied, changed from r1714192, labs/yay/trunk/core/src/main/java/org/apache/yay/core/SigmoidFunction.java labs/yay/trunk/core/src/main/java/org/apache/yay/SoftmaxActivationFunction.java - copied, changed from r1715735, labs/yay/trunk/core/src/main/java/org/apache/yay/core/SoftmaxActivationFunction.java labs/yay/trunk/core/src/main/java/org/apache/yay/StepActivationFunction.java - copied, changed from r1714192, labs/yay/trunk/core/src/main/java/org/apache/yay/core/StepActivationFunction.java labs/yay/trunk/core/src/main/java/org/apache/yay/TanhFunction.java - copied, changed from r1714192, labs/yay/trunk/core/src/main/java/org/apache/yay/core/TanhFunction.java labs/yay/trunk/core/src/test/java/org/apache/yay/ShallowFeedForwardNeuralNetworkTest.java (with props) labs/yay/trunk/core/src/test/java/org/apache/yay/SigmoidFunctionTest.java - copied, changed from r1707760, labs/yay/trunk/core/src/test/java/org/apache/yay/core/SigmoidFunctionTest.java Removed: labs/yay/trunk/api/ labs/yay/trunk/core/src/main/java/org/apache/yay/core/ labs/yay/trunk/core/src/test/java/org/apache/yay/core/ Modified: labs/yay/trunk/core/pom.xml labs/yay/trunk/pom.xml Modified: labs/yay/trunk/core/pom.xml URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/pom.xml?rev=1724646&r1=1724645&r2=1724646&view=diff ============================================================================== --- labs/yay/trunk/core/pom.xml (original) +++ labs/yay/trunk/core/pom.xml Thu Jan 14 16:19:32 2016 @@ -22,17 +22,12 @@ <parent> <groupId>org.apache.yay</groupId> <artifactId>yay-parent</artifactId> - <version>0.1-SNAPSHOT</version> + <version>0.2-SNAPSHOT</version> <relativePath>../</relativePath> </parent> <name>Yay core</name> <dependencies> <dependency> - <groupId>org.apache.yay</groupId> - <artifactId>yay-api</artifactId> - <version>${project.version}</version> - </dependency> - <dependency> <groupId>org.mockito</groupId> <artifactId>mockito-core</artifactId> <version>1.9.5</version> Added: labs/yay/trunk/core/src/main/java/org/apache/yay/ActivationFunction.java URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/ActivationFunction.java?rev=1724646&view=auto ============================================================================== --- labs/yay/trunk/core/src/main/java/org/apache/yay/ActivationFunction.java (added) +++ labs/yay/trunk/core/src/main/java/org/apache/yay/ActivationFunction.java Thu Jan 14 16:19:32 2016 @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. 
See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.yay; + +import org.apache.commons.math3.linear.RealMatrix; + +/** + * An activation function AF : S -> S receives a signal and generates a new signal. + * An activation function AF typically has horizontal asymptotes (e.g. at 0 and 1) and is non-decreasing, + * with AF and its first derivative AF' both being computable. + * Activation functions are typically used by neurons in order to propagate the "signal" + * throughout the whole network. + */ +public interface ActivationFunction { + + /** + * Apply this <code>ActivationFunction</code> to the given matrix of signals, generating a new matrix of transformed + * signals. + * + * @param weights the matrix of input signals the activation should be applied to + * @return the output signal generated as a {@link RealMatrix} + */ + RealMatrix applyMatrix(RealMatrix weights); + +} Propchange: labs/yay/trunk/core/src/main/java/org/apache/yay/ActivationFunction.java ------------------------------------------------------------------------------ svn:eol-style = native Added: labs/yay/trunk/core/src/main/java/org/apache/yay/IdentityActivationFunction.java URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/IdentityActivationFunction.java?rev=1724646&view=auto ============================================================================== --- labs/yay/trunk/core/src/main/java/org/apache/yay/IdentityActivationFunction.java (added) +++ labs/yay/trunk/core/src/main/java/org/apache/yay/IdentityActivationFunction.java Thu Jan 14 16:19:32 2016 @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License.
+ */ +package org.apache.yay; + +import org.apache.commons.math3.linear.RealMatrix; +import org.apache.yay.ActivationFunction; + +/** + * An {@link org.apache.yay.ActivationFunction} which just replicates the signal without changing it + */ +public class IdentityActivationFunction implements ActivationFunction { + + @Override + public RealMatrix applyMatrix(RealMatrix weights) { + return weights; + } + +} Propchange: labs/yay/trunk/core/src/main/java/org/apache/yay/IdentityActivationFunction.java ------------------------------------------------------------------------------ svn:eol-style = native Added: labs/yay/trunk/core/src/main/java/org/apache/yay/ShallowFeedForwardNeuralNetwork.java URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/ShallowFeedForwardNeuralNetwork.java?rev=1724646&view=auto ============================================================================== --- labs/yay/trunk/core/src/main/java/org/apache/yay/ShallowFeedForwardNeuralNetwork.java (added) +++ labs/yay/trunk/core/src/main/java/org/apache/yay/ShallowFeedForwardNeuralNetwork.java Thu Jan 14 16:19:32 2016 @@ -0,0 +1,357 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.yay; + +import org.apache.commons.lang3.SystemUtils; +import org.apache.commons.math3.linear.ArrayRealVector; +import org.apache.commons.math3.linear.MatrixUtils; +import org.apache.commons.math3.linear.RealMatrix; +import org.apache.commons.math3.linear.RealMatrixChangingVisitor; +import org.apache.commons.math3.linear.RealVector; + +import java.util.Arrays; +import java.util.Random; + +/** + * A shallow feed forward neural network. + * It learns its weights through the backpropagation algorithm via stochastic gradient descent applied to a collection of + * training samples. + * Each example is a real vector whose first N elements (where N is the number of network output units) are the expected + * outputs and the remaining elements are the input features. + */ +public class ShallowFeedForwardNeuralNetwork { + + private final Configuration configuration; + + /** + * Each RealMatrix maps weights between two layers. + * E.g.: weights[0] controls the function mapping from layer 0 to layer 1. + * If the network has 4 units in layer 0 and 5 units in layer 1, then weights[0] will be of dimension 5x4, plus bias terms. + * A network having layers with 3, 4 and 2 units each will have the following weights matrix dimensions (excluding bias terms): + * - weights[0] : 4x3 + * - weights[1] : 2x4 + * <p> + * the first row of the weights[0] matrix holds the weights of the connections from each neuron of the first layer into the first neuron of the second layer, + * the second row of weights[0] holds the weights of the connections into the second neuron of the second layer, etc.
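+ * <p>
+ * Note the bias handling in createRandomWeights below: a bias unit is appended to every layer but the last, so for
+ * layers {3, 4, 2} the matrices actually allocated are weights[0] : 5x4 and weights[1] : 2x5. In every matrix but
+ * the last, row 0 (the bias unit's own row) is initialized to 0; the remaining entries of column 0 (the weights on
+ * the constant bias input) are initialized to 1. Those pinned entries are deliberately skipped by the weight
+ * updates in getUpdatedWeights.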
+ */ + private RealMatrix[] weights; + + public ShallowFeedForwardNeuralNetwork(Configuration configuration) { + this.configuration = configuration; + initialize(); + } + + private void initialize() { + weights = createRandomWeights(); + } + + private RealMatrix[] createRandomWeights() { + Random r = new Random(); + int[] layers = new int[configuration.layers.length]; + for (int i = 0; i < layers.length; i++) { + layers[i] = configuration.layers[i] + (i < layers.length - 1 ? 1 : 0); + } + int weightsCount = layers.length - 1; + + RealMatrix[] initialWeights = new RealMatrix[weightsCount]; + + for (int i = 0; i < weightsCount; i++) { + + RealMatrix matrix = MatrixUtils.createRealMatrix(layers[i + 1], layers[i]); + final int finalI = i; + matrix.walkInOptimizedOrder(new RealMatrixChangingVisitor() { + @Override + public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) { + + } + + @Override + public double visit(int row, int column, double value) { + if (finalI != weightsCount - 1 && row == 0) { + return 0d; + } else if (column == 0) { + return 1d; + } + return r.nextInt(100) / 101d; + } + + @Override + public double end() { + return 0; + } + }); + + initialWeights[i] = matrix; + } + return initialWeights; + + } + + // --- stochastic gradient descent --- + + /** + * Perform weight learning from the training examples using the stochastic gradient descent algorithm. + * + * @param samples the training examples + * @return the final cost with the updated weights + * @throws Exception if SGD fails to converge or any numerical error happens + */ + public double learnWeights(double[]... samples) throws Exception { + double newCost; + int iterations = 0; + + double cost = Double.MAX_VALUE; + long start = System.currentTimeMillis(); + while (true) { + if (iterations % (1 + (configuration.maxIterations / 100)) == 0) { + long time = (System.currentTimeMillis() - start) / 1000; + if (time > 60) { + System.out.println("cost is " + cost + " after " + iterations + " iterations in " + (time / 60) + " minutes (" + ((double) iterations / time) + " ips)"); + } + } + // current training example + double[] sample = samples[iterations % samples.length]; + + int outputLayerSize = configuration.layers[configuration.layers.length - 1]; + double[] expectedOutput = getSampleOutput(sample, outputLayerSize); + double[] input = getSampleInput(sample, outputLayerSize); + + double[] predictedOutput = predictOutput(input); // TODO : use debugOutput to avoid performing it again when calculating derivatives + + // calculate cost + newCost = calculateCost(expectedOutput, predictedOutput, samples.length); + + if (Double.POSITIVE_INFINITY == newCost) { + throw new Exception("failed to converge at iteration " + iterations + " with alpha " + configuration.alpha + " : cost going from " + cost + " to " + newCost); + } else if (iterations > 1 && (cost == newCost || newCost < configuration.threshold || iterations > configuration.maxIterations)) { + System.out.println("successfully converged after " + (iterations - 1) + " iterations (alpha:" + configuration.alpha + ",threshold:" + configuration.threshold + ") with cost " + newCost); + break; + } else if (Double.isNaN(newCost)) { + throw new Exception("failed to converge at iteration " + iterations + " with alpha " + configuration.alpha + " : cost calculation underflow"); + } + + // update registered cost + cost = newCost; + + // calculate the derivatives to update the parameters + RealMatrix[] derivatives = calculateDerivatives(input,
expectedOutput, sample.length); + + // update the weights + weights = getUpdatedWeights(derivatives); + + iterations++; + } + return newCost; + } + + // --- sample parsing --- + + private double[] getSampleInput(double[] sample, int outputLayerSize) { + double[] input = Arrays.copyOfRange(sample, outputLayerSize, sample.length); + double[] result = new double[input.length + 1]; + result[0] = 1d; + System.arraycopy(input, 0, result, 1, input.length); + return result; + } + + private double[] getSampleOutput(double[] sample, int outputLayerSize) { + return Arrays.copyOfRange(sample, 0, outputLayerSize); + } + + // --- backpropagation --- + + private RealMatrix[] calculateDerivatives(double[] input, double[] output, int size) throws Exception { + RealVector[] deltaVectors = new RealVector[weights.length]; + RealMatrix[] deltas = new RealMatrix[weights.length]; + RealMatrix[] ds = new RealMatrix[weights.length]; + + // compute delta vectors + + // get activations from feed forward propagation + RealVector[] activations = applyFeedForward(input); + + // calculate output error (corresponding to the last delta vector) + RealVector nextLayerDelta = calculateOutputError(output, activations); + + deltaVectors[weights.length - 1] = nextLayerDelta; + + for (int l = weights.length - 1; l > 0; l--) { + RealVector currentActivationsVector = activations[l - 1]; + nextLayerDelta = calculateDeltaVector(weights[l], currentActivationsVector, nextLayerDelta); + + // collect delta vectors for this example + deltaVectors[l - 1] = nextLayerDelta; + } + + RealVector[] newActivations = new RealVector[activations.length]; + newActivations[0] = MatrixUtils.createRealVector(input); + System.arraycopy(activations, 0, newActivations, 1, activations.length - 1); + + // compute deltas + for (int l = deltas.length - 1; l >= 0; l--) { + RealMatrix realMatrix = deltaVectors[l].outerProduct(newActivations[l]); + if (deltas[l] == null) { + deltas[l] = realMatrix; + } else { + deltas[l] = deltas[l].add(realMatrix); + } + } + + // compute derivatives + for (int i = 0; i < deltas.length; i++) { + ds[i] = deltas[i].scalarMultiply(1d / size); + } + + // regularization + int l = 0; + for (RealMatrix d : ds) { + final int finalL = l; + d.walkInOptimizedOrder(new RealMatrixChangingVisitor() { + @Override + public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) { + + } + + @Override + public double visit(int row, int column, double value) { + if (column != 0) { + return value + configuration.alpha * weights[finalL].getEntry(row, column); + } else { + return value; + } + } + + @Override + public double end() { + return 0; + } + }); + l++; + } + + return ds; + } + + private RealVector calculateDeltaVector(RealMatrix weight, RealVector activationsVector, RealVector nextLayerDelta) { + double[] ones = new double[activationsVector.getDimension()]; + Arrays.fill(ones, 1d); + return weight.preMultiply(nextLayerDelta).ebeMultiply(activationsVector.ebeMultiply(MatrixUtils.createRealVector(ones).subtract(activationsVector))); + } + + private RealVector calculateOutputError(double[] expectedOutput, RealVector[] activations) { + RealVector output = activations[activations.length - 1]; + RealVector learnedOutputRealVector = new ArrayRealVector(expectedOutput); + return output.subtract(learnedOutputRealVector); + } + + private RealMatrix[] getUpdatedWeights(final RealMatrix[] derivatives) { + RealMatrix[] updatedParameters = new RealMatrix[weights.length]; + + for (int l = 0; l < weights.length; l++) { + 
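+ // entry-wise gradient descent step: value := value - alpha * derivative (alpha is the learning rate);
+ // entries still holding their pinned initial values (row 0 == 0d, column 0 == 1d) are left unchanged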
RealMatrix realMatrix = weights[l].copy(); + final int finalL = l; + RealMatrixChangingVisitor visitor = new RealMatrixChangingVisitor() { + + @Override + public void start(int rows, int columns, int startRow, int endRow, int startColumn, int endColumn) { + + } + + @Override + public double visit(int row, int column, double value) { + if (!(row == 0 && value == 0d) && !(column == 0 && value == 1d)) { + return value - configuration.alpha * derivatives[finalL].getEntry(row, column); + } else { + return value; + } + } + + @Override + public double end() { + return 0; + } + }; + realMatrix.walkInOptimizedOrder(visitor); + updatedParameters[l] = realMatrix; + } + return updatedParameters; + } + + // --- cost --- + + private double calculateCost(double[] expectedOutput, double[] predictedOutput, int size) { + // neural network cost function + double res = 0d; + for (int i = 0; i < predictedOutput.length; i++) { + Double yo = expectedOutput[i]; + Double ho = predictedOutput[i]; + res += yo * Math.log(ho) + (1d - yo) + * Math.log(1d - ho); + } + return (-1d / size) * res; + } + + // --- feed forward --- + + public double[] predictOutput(double[] input) throws Exception { + if (input.length == configuration.layers[0]) { + double[] i2 = new double[input.length + 1]; + i2[0] = 1d; + System.arraycopy(input, 0, i2, 1, input.length); + input = i2; + } + RealVector[] debugOutput = applyFeedForward(input); + RealVector d = debugOutput[debugOutput.length - 1]; + return d.toArray(); + } + + private RealVector[] applyFeedForward(double[] input) throws Exception { + if (weights == null) { + throw new Exception("weights undefined, perform learning first"); + } + + RealVector[] debugOutput = new RealVector[weights.length]; + + RealMatrix x = MatrixUtils.createRowRealMatrix(input); + for (int w = 0; w < weights.length; w++) { + // compute matrix multiplication + x = x.multiply(weights[w].transpose()); + + // get activation function for w-th layer + int idx = configuration.activationFunctions.length == weights.length ? 
w : 0; + + // apply the activation function to each element in the matrix + x = configuration.activationFunctions[idx].applyMatrix(x.getRowMatrix(0)); + + debugOutput[w] = x.getRowVector(0); + } + return debugOutput; + } + + // --- neural network configuration --- + + public static class Configuration { + protected int maxIterations; + protected double alpha; + protected double threshold; + protected int[] layers; + protected ActivationFunction[] activationFunctions; + } +} \ No newline at end of file Propchange: labs/yay/trunk/core/src/main/java/org/apache/yay/ShallowFeedForwardNeuralNetwork.java ------------------------------------------------------------------------------ svn:eol-style = native Copied: labs/yay/trunk/core/src/main/java/org/apache/yay/SigmoidFunction.java (from r1714192, labs/yay/trunk/core/src/main/java/org/apache/yay/core/SigmoidFunction.java) URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/SigmoidFunction.java?p2=labs/yay/trunk/core/src/main/java/org/apache/yay/SigmoidFunction.java&p1=labs/yay/trunk/core/src/main/java/org/apache/yay/core/SigmoidFunction.java&r1=1714192&r2=1724646&rev=1724646&view=diff ============================================================================== --- labs/yay/trunk/core/src/main/java/org/apache/yay/core/SigmoidFunction.java (original) +++ labs/yay/trunk/core/src/main/java/org/apache/yay/SigmoidFunction.java Thu Jan 14 16:19:32 2016 @@ -16,7 +16,7 @@ * specific language governing permissions and limitations * under the License. */ -package org.apache.yay.core; +package org.apache.yay; import org.apache.commons.math3.linear.RealMatrix; import org.apache.commons.math3.linear.RealMatrixChangingVisitor; @@ -25,13 +25,9 @@ import org.apache.yay.ActivationFunction /** * See <a href="http://en.wikipedia.org/wiki/Sigmoid_function">here</a>. */ -public class SigmoidFunction implements ActivationFunction<Double> { +public class SigmoidFunction implements ActivationFunction { - public Double apply(RealMatrix matrix, final Double input) { - return sigmoid(input); - } - - private double sigmoid(Double input) { + public double apply(Double input) { return 1d / (1d + Math.exp(-1d * input)); } @@ -46,7 +42,7 @@ public class SigmoidFunction implements @Override public double visit(int row, int column, double value) { - return sigmoid(value); + return apply(value); } @Override Copied: labs/yay/trunk/core/src/main/java/org/apache/yay/SoftmaxActivationFunction.java (from r1715735, labs/yay/trunk/core/src/main/java/org/apache/yay/core/SoftmaxActivationFunction.java) URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/SoftmaxActivationFunction.java?p2=labs/yay/trunk/core/src/main/java/org/apache/yay/SoftmaxActivationFunction.java&p1=labs/yay/trunk/core/src/main/java/org/apache/yay/core/SoftmaxActivationFunction.java&r1=1715735&r2=1724646&rev=1724646&view=diff ============================================================================== --- labs/yay/trunk/core/src/main/java/org/apache/yay/core/SoftmaxActivationFunction.java (original) +++ labs/yay/trunk/core/src/main/java/org/apache/yay/SoftmaxActivationFunction.java Thu Jan 14 16:19:32 2016 @@ -16,7 +16,7 @@ * specific language governing permissions and limitations * under the License.
*/ -package org.apache.yay.core; +package org.apache.yay; import org.apache.commons.math3.linear.RealMatrix; import org.apache.commons.math3.linear.RealMatrixChangingVisitor; @@ -26,14 +26,7 @@ import org.apache.yay.ActivationFunction /** * Softmax activation function */ -public class SoftmaxActivationFunction implements ActivationFunction<Double> { - - @Override - public Double apply(RealMatrix weights, Double signal) { - double num = Math.exp(signal); - double den = expDen(weights); - return num / den; - } +public class SoftmaxActivationFunction implements ActivationFunction { public RealMatrix applyMatrix(RealMatrix weights) { RealMatrix matrix = weights.copy(); Copied: labs/yay/trunk/core/src/main/java/org/apache/yay/StepActivationFunction.java (from r1714192, labs/yay/trunk/core/src/main/java/org/apache/yay/core/StepActivationFunction.java) URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/StepActivationFunction.java?p2=labs/yay/trunk/core/src/main/java/org/apache/yay/StepActivationFunction.java&p1=labs/yay/trunk/core/src/main/java/org/apache/yay/core/StepActivationFunction.java&r1=1714192&r2=1724646&rev=1724646&view=diff ============================================================================== --- labs/yay/trunk/core/src/main/java/org/apache/yay/core/StepActivationFunction.java (original) +++ labs/yay/trunk/core/src/main/java/org/apache/yay/StepActivationFunction.java Thu Jan 14 16:19:32 2016 @@ -16,16 +16,15 @@ * specific language governing permissions and limitations * under the License. */ -package org.apache.yay.core; +package org.apache.yay; import org.apache.commons.math3.linear.RealMatrix; import org.apache.commons.math3.linear.RealMatrixChangingVisitor; -import org.apache.yay.ActivationFunction; /** * An {@link org.apache.yay.ActivationFunction} implementing a step function of the input */ -public class StepActivationFunction implements ActivationFunction<Double> { +public class StepActivationFunction implements ActivationFunction { private final double center; @@ -33,11 +32,6 @@ public class StepActivationFunction impl this.center = center; } - @Override - public Double apply(RealMatrix matrix, Double signal) { - return step(signal); - } - private double step(Double signal) { return signal >= center ? 1d : 0d; } Copied: labs/yay/trunk/core/src/main/java/org/apache/yay/TanhFunction.java (from r1714192, labs/yay/trunk/core/src/main/java/org/apache/yay/core/TanhFunction.java) URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/TanhFunction.java?p2=labs/yay/trunk/core/src/main/java/org/apache/yay/TanhFunction.java&p1=labs/yay/trunk/core/src/main/java/org/apache/yay/core/TanhFunction.java&r1=1714192&r2=1724646&rev=1724646&view=diff ============================================================================== --- labs/yay/trunk/core/src/main/java/org/apache/yay/core/TanhFunction.java (original) +++ labs/yay/trunk/core/src/main/java/org/apache/yay/TanhFunction.java Thu Jan 14 16:19:32 2016 @@ -16,7 +16,7 @@ * specific language governing permissions and limitations * under the License. 
*/ -package org.apache.yay.core; +package org.apache.yay; import org.apache.commons.math3.linear.RealMatrix; import org.apache.commons.math3.linear.RealMatrixChangingVisitor; @@ -25,11 +25,7 @@ import org.apache.yay.ActivationFunction /** * Tanh activation function */ -public class TanhFunction implements ActivationFunction<Double> { - @Override - public Double apply(RealMatrix matrix, Double signal) { - return Math.tanh(signal); - } +public class TanhFunction implements ActivationFunction { @Override public RealMatrix applyMatrix(RealMatrix weights) { Added: labs/yay/trunk/core/src/test/java/org/apache/yay/ShallowFeedForwardNeuralNetworkTest.java URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/ShallowFeedForwardNeuralNetworkTest.java?rev=1724646&view=auto ============================================================================== --- labs/yay/trunk/core/src/test/java/org/apache/yay/ShallowFeedForwardNeuralNetworkTest.java (added) +++ labs/yay/trunk/core/src/test/java/org/apache/yay/ShallowFeedForwardNeuralNetworkTest.java Thu Jan 14 16:19:32 2016 @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.yay; + +import org.junit.Test; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; + +/** + * Tests for {@link ShallowFeedForwardNeuralNetwork} + */ +public class ShallowFeedForwardNeuralNetworkTest { + + @Test + public void testLearnAndPredict() throws Exception { + ShallowFeedForwardNeuralNetwork.Configuration configuration = new ShallowFeedForwardNeuralNetwork.Configuration(); + configuration.alpha = 0.0001d; + configuration.layers = new int[]{3, 4, 1}; + configuration.maxIterations = 10000; + configuration.threshold = 0.004d; + configuration.activationFunctions = new ActivationFunction[]{new SigmoidFunction()}; + + ShallowFeedForwardNeuralNetwork neuralNetwork = new ShallowFeedForwardNeuralNetwork(configuration); + + assertNotNull(neuralNetwork); + double[][] samples = new double[3][4]; + samples[0][0] = 0.1d; + samples[0][1] = 0.2d; + samples[0][2] = 0.3d; + samples[0][3] = 0.4d; + samples[1][0] = 0.5d; + samples[1][1] = 0.6d; + samples[1][2] = 0.7d; + samples[1][3] = 0.8d; + samples[2][0] = 0.9d; + samples[2][1] = 0.1d; + samples[2][2] = 0.2d; + samples[2][3] = 0.3d; + double cost = neuralNetwork.learnWeights(samples); + assertTrue(cost > 0 && cost < 10); + + double[] doubles = neuralNetwork.predictOutput(new double[]{0.7d, 0.8d, 0.9d}); + assertNotNull(doubles); + + assertEquals(0.4d, doubles[0], 0.4d); + } +} \ No newline at end of file Propchange: labs/yay/trunk/core/src/test/java/org/apache/yay/ShallowFeedForwardNeuralNetworkTest.java ------------------------------------------------------------------------------ svn:eol-style = native Copied: labs/yay/trunk/core/src/test/java/org/apache/yay/SigmoidFunctionTest.java (from r1707760, labs/yay/trunk/core/src/test/java/org/apache/yay/core/SigmoidFunctionTest.java) URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/SigmoidFunctionTest.java?p2=labs/yay/trunk/core/src/test/java/org/apache/yay/SigmoidFunctionTest.java&p1=labs/yay/trunk/core/src/test/java/org/apache/yay/core/SigmoidFunctionTest.java&r1=1707760&r2=1724646&rev=1724646&view=diff ============================================================================== --- labs/yay/trunk/core/src/test/java/org/apache/yay/core/SigmoidFunctionTest.java (original) +++ labs/yay/trunk/core/src/test/java/org/apache/yay/SigmoidFunctionTest.java Thu Jan 14 16:19:32 2016 @@ -16,33 +16,34 @@ * specific language governing permissions and limitations * under the License. 
*/ -package org.apache.yay.core; +package org.apache.yay; +import org.apache.yay.SigmoidFunction; import org.junit.Test; import static org.junit.Assert.assertEquals; /** - * Tests for {@link org.apache.yay.core.SigmoidFunction} + * Tests for {@link SigmoidFunction} */ public class SigmoidFunctionTest { @Test public void testCorrectOutput() throws Exception { SigmoidFunction sigmoidFunction = new SigmoidFunction(); - Double output = sigmoidFunction.apply(null, 38d); + Double output = sigmoidFunction.apply(38d); assertEquals(Double.valueOf(1d), output); - output = sigmoidFunction.apply(null, 6d); + output = sigmoidFunction.apply(6d); assertEquals(Double.valueOf(0.9975273768433653d), output); - output = sigmoidFunction.apply(null, 2.5d); + output = sigmoidFunction.apply(2.5d); assertEquals(Double.valueOf(0.9241418199787566d), output); - output = sigmoidFunction.apply(null, -2.5d); + output = sigmoidFunction.apply(-2.5d); assertEquals(Double.valueOf(0.07585818002124355d), output); - output = sigmoidFunction.apply(null, 0d); + output = sigmoidFunction.apply(0d); assertEquals(Double.valueOf(0.5d), output); } } Modified: labs/yay/trunk/pom.xml URL: http://svn.apache.org/viewvc/labs/yay/trunk/pom.xml?rev=1724646&r1=1724645&r2=1724646&view=diff ============================================================================== --- labs/yay/trunk/pom.xml (original) +++ labs/yay/trunk/pom.xml Thu Jan 14 16:19:32 2016 @@ -5,7 +5,7 @@ <groupId>org.apache.yay</groupId> <artifactId>yay-parent</artifactId> <packaging>pom</packaging> - <version>0.1-SNAPSHOT</version> + <version>0.2-SNAPSHOT</version> <name>Yay parent</name> <url>http://svn.apache.org/repos/asf/labs/yay</url> <organization> @@ -206,7 +206,6 @@ </resources> </build> <modules> - <module>api</module> <module>core</module> </modules> <profiles>
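A minimal end-to-end driver for the API introduced in this revision, mirroring the shipped test; this is an illustrative sketch only (the class name is hypothetical, the sample values are arbitrary, and the class must live in package org.apache.yay because the Configuration fields are package-visible):

package org.apache.yay;

/**
 * Hypothetical usage sketch for ShallowFeedForwardNeuralNetwork (not part of this commit).
 */
public class ShallowNetworkExample {

  public static void main(String[] args) throws Exception {
    // network shape: 3 input units, 4 hidden units, 1 output unit
    ShallowFeedForwardNeuralNetwork.Configuration conf = new ShallowFeedForwardNeuralNetwork.Configuration();
    conf.alpha = 0.0001d; // learning rate (also reused as the regularization factor in calculateDerivatives)
    conf.layers = new int[]{3, 4, 1};
    conf.maxIterations = 10000;
    conf.threshold = 0.004d; // stop once the cost falls below this value
    conf.activationFunctions = new ActivationFunction[]{new SigmoidFunction()};

    ShallowFeedForwardNeuralNetwork network = new ShallowFeedForwardNeuralNetwork(conf);

    // one sample per row: the first value is the expected output, the remaining three are input features
    double[][] samples = {
        {0.1d, 0.2d, 0.3d, 0.4d},
        {0.5d, 0.6d, 0.7d, 0.8d}
    };
    double cost = network.learnWeights(samples);

    double[] prediction = network.predictOutput(new double[]{0.7d, 0.8d, 0.9d});
    System.out.println("cost: " + cost + ", prediction: " + prediction[0]);
  }
}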