I need help: I want to unsubscribe from this mailing list.

On Sep 28, 2015 11:50 AM, <tomm...@apache.org> wrote:
> Author: tommaso > Date: Mon Sep 28 16:49:57 2015 > New Revision: 1705721 > > URL: http://svn.apache.org/viewvc?rev=1705721&view=rev > Log: > switch from batch to stochastic GD in backprop, abstracted derivative > update function > > Added: > > labs/yay/trunk/api/src/main/java/org/apache/yay/DerivativeUpdateFunction.java > > labs/yay/trunk/core/src/main/java/org/apache/yay/core/DefaultDerivativeUpdateFunction.java > Removed: > labs/yay/trunk/api/src/main/java/org/apache/yay/HypothesisFactory.java > Modified: > labs/yay/trunk/api/src/main/java/org/apache/yay/NeuralNetwork.java > > labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java > > labs/yay/trunk/core/src/main/java/org/apache/yay/core/BasicPerceptron.java > > labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java > > labs/yay/trunk/core/src/main/java/org/apache/yay/core/NeuralNetworkFactory.java > > labs/yay/trunk/core/src/main/java/org/apache/yay/core/utils/ExamplesFactory.java > > labs/yay/trunk/core/src/test/java/org/apache/yay/core/NeuralNetworkFactoryTest.java > > Added: > labs/yay/trunk/api/src/main/java/org/apache/yay/DerivativeUpdateFunction.java > URL: > http://svn.apache.org/viewvc/labs/yay/trunk/api/src/main/java/org/apache/yay/DerivativeUpdateFunction.java?rev=1705721&view=auto > > ============================================================================== > --- > labs/yay/trunk/api/src/main/java/org/apache/yay/DerivativeUpdateFunction.java > (added) > +++ > labs/yay/trunk/api/src/main/java/org/apache/yay/DerivativeUpdateFunction.java > Mon Sep 28 16:49:57 2015 > @@ -0,0 +1,30 @@ > +/* > + * Licensed to the Apache Software Foundation (ASF) under one > + * or more contributor license agreements. See the NOTICE file > + * distributed with this work for additional information > + * regarding copyright ownership. 
The ASF licenses this file > + * to you under the Apache License, Version 2.0 (the > + * "License"); you may not use this file except in compliance > + * with the License. You may obtain a copy of the License at > + * > + * http://www.apache.org/licenses/LICENSE-2.0 > + * > + * Unless required by applicable law or agreed to in writing, > + * software distributed under the License is distributed on an > + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY > + * KIND, either express or implied. See the License for the > + * specific language governing permissions and limitations > + * under the License. > + */ > +package org.apache.yay; > + > +import org.apache.commons.math3.linear.RealMatrix; > +import org.apache.yay.TrainingSet; > + > +/** > + * Derivatives update function > + */ > +public interface DerivativeUpdateFunction<F,O> { > + > + RealMatrix[] updateParameters(RealMatrix[] weights, TrainingSet<F,O> > trainingSet); > +} > > Modified: > labs/yay/trunk/api/src/main/java/org/apache/yay/NeuralNetwork.java > URL: > http://svn.apache.org/viewvc/labs/yay/trunk/api/src/main/java/org/apache/yay/NeuralNetwork.java?rev=1705721&r1=1705720&r2=1705721&view=diff > > ============================================================================== > --- labs/yay/trunk/api/src/main/java/org/apache/yay/NeuralNetwork.java > (original) > +++ labs/yay/trunk/api/src/main/java/org/apache/yay/NeuralNetwork.java Mon > Sep 28 16:49:57 2015 > @@ -21,8 +21,19 @@ package org.apache.yay; > import org.apache.commons.math3.linear.RealMatrix; > > /** > - * A neural network is a layered connected graph of elaboration units > + * A Neural Network is a layered connected graph of elaboration units. > + * > + * It takes a double vector as input and produces a double vector as > output. 
> */ > public interface NeuralNetwork extends Hypothesis<RealMatrix, Double, > Double> { > > + /** > + * Predict the output for a given input > + * > + * @param input the input to evaluate > + * @return the predicted output > + * @throws PredictionException if any error occurs during the > prediction phase > + */ > + Double[] getOutputVector(Input<Double> input) throws > PredictionException; > + > } > > Modified: > labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java > URL: > http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java?rev=1705721&r1=1705720&r2=1705721&view=diff > > ============================================================================== > --- > labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java > (original) > +++ > labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java > Mon Sep 28 16:49:57 2015 > @@ -19,6 +19,7 @@ > package org.apache.yay.core; > > import java.util.Arrays; > +import java.util.DoubleSummaryStatistics; > import java.util.Iterator; > > import org.apache.commons.math3.linear.Array2DRowRealMatrix; > @@ -26,6 +27,7 @@ import org.apache.commons.math3.linear.A > import org.apache.commons.math3.linear.RealMatrix; > import org.apache.commons.math3.linear.RealVector; > import org.apache.yay.CostFunction; > +import org.apache.yay.DerivativeUpdateFunction; > import org.apache.yay.LearningStrategy; > import org.apache.yay.NeuralNetwork; > import org.apache.yay.PredictionStrategy; > @@ -46,6 +48,7 @@ public class BackPropagationLearningStra > > private final PredictionStrategy<Double, Double> predictionStrategy; > private final CostFunction<RealMatrix, Double, Double> costFunction; > + private final DerivativeUpdateFunction<Double, Double> > derivativeUpdateFunction; > private final double alpha; > private final double threshold; > private final int batch; > @@ -63,6 
+66,7 @@ public class BackPropagationLearningStra > this.alpha = alpha; > this.threshold = threshold; > this.batch = batch; > + this.derivativeUpdateFunction = new > DefaultDerivativeUpdateFunction(predictionStrategy); > } > > public BackPropagationLearningStrategy() { > @@ -72,6 +76,7 @@ public class BackPropagationLearningStra > this.alpha = DEFAULT_ALPHA; > this.threshold = DEFAULT_THRESHOLD; > this.batch = 1; > + this.derivativeUpdateFunction = new > DefaultDerivativeUpdateFunction(predictionStrategy); > } > > @Override > @@ -114,7 +119,7 @@ public class BackPropagationLearningStra > cost = newCost; > > // calculate the derivatives to update the parameters > - RealMatrix[] derivatives = calculateDerivatives(weightsMatrixSet, > samples); > + RealMatrix[] derivatives = > derivativeUpdateFunction.updateParameters(weightsMatrixSet, samples); > > // calculate the updated parameters > updatedWeights = updateWeights(weightsMatrixSet, derivatives, > alpha); > @@ -131,48 +136,6 @@ public class BackPropagationLearningStra > return updatedWeights; > } > > - private RealMatrix[] calculateDerivatives(RealMatrix[] > weightsMatrixSet, TrainingSet<Double, Double> trainingExamples) throws > WeightLearningException { > - // set up the accumulator matrix(es) > - RealMatrix[] triangle = new RealMatrix[weightsMatrixSet.length]; > - RealVector[] deltaVectors = new RealVector[weightsMatrixSet.length]; > - > - int noOfMatrixes = weightsMatrixSet.length - 1; > - double count = 0; > - for (TrainingExample<Double, Double> trainingExample : > trainingExamples) { > - try { > - // get activations from feed forward propagation > - RealVector[] activations = > predictionStrategy.debugOutput(ConversionUtils.toValuesCollection(trainingExample.getFeatures()), > weightsMatrixSet); > - > - // calculate output error (corresponding to the last delta^l) > - RealVector nextLayerDelta = calculateOutputError(trainingExample, > activations); > - > - deltaVectors[noOfMatrixes] = nextLayerDelta; > - > - // 
back prop the error and update the deltas accordingly > - for (int l = noOfMatrixes; l > 0; l--) { > - RealVector currentActivationsVector = activations[l - 1]; > - nextLayerDelta = calculateDeltaVector(weightsMatrixSet[l], > currentActivationsVector, nextLayerDelta); > - > - // collect delta vectors for this example > - deltaVectors[l - 1] = nextLayerDelta; > - } > - > - RealVector[] newActivations = new RealVector[activations.length]; > - newActivations[0] = > ConversionUtils.toRealVector(ConversionUtils.toValuesCollection(trainingExample.getFeatures())); > - System.arraycopy(activations, 0, newActivations, 1, > activations.length - 1); > - > - // update triangle (big delta matrix) > - updateTriangle(triangle, newActivations, deltaVectors, > weightsMatrixSet); > - > - } catch (Exception e) { > - throw new WeightLearningException("error during derivatives > calculation", e); > - } > - count++; > - } > - > - return createDerivatives(triangle, count); > - } > - > private RealMatrix[] updateWeights(RealMatrix[] weightsMatrixSet, > RealMatrix[] derivatives, double alpha) { > RealMatrix[] updatedParameters = new > RealMatrix[weightsMatrixSet.length]; > for (int l = 0; l < weightsMatrixSet.length; l++) { > @@ -187,48 +150,4 @@ public class BackPropagationLearningStra > return updatedParameters; > } > > - private RealMatrix[] createDerivatives(RealMatrix[] triangle, double > count) { > - RealMatrix[] derivatives = new RealMatrix[triangle.length]; > - for (int i = 0; i < triangle.length; i++) { > - // TODO : introduce regularization diversification on bias term > (currently not regularized) > - derivatives[i] = triangle[i].scalarMultiply(1d / count); > - } > - return derivatives; > - } > - > - private void updateTriangle(RealMatrix[] triangle, RealVector[] > activations, RealVector[] deltaVectors, RealMatrix[] weightsMatrixSet) { > - for (int l = weightsMatrixSet.length - 1; l >= 0; l--) { > - RealMatrix realMatrix = > deltaVectors[l].outerProduct(activations[l]); > - if 
(triangle[l] == null) { > - triangle[l] = realMatrix; > - } else { > - triangle[l] = triangle[l].add(realMatrix); > - } > - } > - } > - > - private RealVector calculateDeltaVector(RealMatrix thetaL, RealVector > activationsVector, RealVector nextLayerDelta) { > - // TODO : remove the bias term from the error calculations > - ArrayRealVector identity = new > ArrayRealVector(activationsVector.getDimension(), 1d); > - RealVector gz = > activationsVector.ebeMultiply(identity.subtract(activationsVector)); // = > a^l .* (1-a^l) > - return thetaL.preMultiply(nextLayerDelta).ebeMultiply(gz); > - } > - > - private RealVector calculateOutputError(TrainingExample<Double, Double> > trainingExample, RealVector[] activations) { > - RealVector output = activations[activations.length - 1]; > - > - Double[] sampleOutput = new Double[output.getDimension()]; > - int sampleOutputIntValue = trainingExample.getOutput().intValue(); > - if (sampleOutputIntValue < sampleOutput.length) { > - sampleOutput[sampleOutputIntValue] = 1d; > - } else if (sampleOutput.length == 1) { > - sampleOutput[0] = trainingExample.getOutput(); > - } else { > - throw new RuntimeException("problem with multiclass output > mapping"); > - } > - RealVector learnedOutputRealVector = new > ArrayRealVector(sampleOutput); // turn example output to a vector > - > - // TODO : improve error calculation -> this could be er_a = out_a * > (1 - out_a) * (tgt_a - out_a) > - return output.subtract(learnedOutputRealVector); > - } > } > > Modified: > labs/yay/trunk/core/src/main/java/org/apache/yay/core/BasicPerceptron.java > URL: > http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/BasicPerceptron.java?rev=1705721&r1=1705720&r2=1705721&view=diff > > ============================================================================== > --- > labs/yay/trunk/core/src/main/java/org/apache/yay/core/BasicPerceptron.java > (original) > +++ > 
labs/yay/trunk/core/src/main/java/org/apache/yay/core/BasicPerceptron.java > Mon Sep 28 16:49:57 2015 > @@ -94,4 +94,13 @@ public class BasicPerceptron implements > return > perceptronNeuron.elaborate(ConversionUtils.toValuesCollection(input.getFeatures()).toArray( > new Double[input.getFeatures().size()])); > } > + > + @Override > + public Double[] getOutputVector(Input<Double> input) throws > PredictionException { > + Double elaborate = > perceptronNeuron.elaborate(ConversionUtils.toValuesCollection(input.getFeatures()).toArray( > + new Double[input.getFeatures().size()])); > + Double[] ar = new Double[1]; > + ar[0] = elaborate; > + return ar; > + } > } > > Added: > labs/yay/trunk/core/src/main/java/org/apache/yay/core/DefaultDerivativeUpdateFunction.java > URL: > http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/DefaultDerivativeUpdateFunction.java?rev=1705721&view=auto > > ============================================================================== > --- > labs/yay/trunk/core/src/main/java/org/apache/yay/core/DefaultDerivativeUpdateFunction.java > (added) > +++ > labs/yay/trunk/core/src/main/java/org/apache/yay/core/DefaultDerivativeUpdateFunction.java > Mon Sep 28 16:49:57 2015 > @@ -0,0 +1,128 @@ > +/* > + * Licensed to the Apache Software Foundation (ASF) under one > + * or more contributor license agreements. See the NOTICE file > + * distributed with this work for additional information > + * regarding copyright ownership. The ASF licenses this file > + * to you under the Apache License, Version 2.0 (the > + * "License"); you may not use this file except in compliance > + * with the License. You may obtain a copy of the License at > + * > + * http://www.apache.org/licenses/LICENSE-2.0 > + * > + * Unless required by applicable law or agreed to in writing, > + * software distributed under the License is distributed on an > + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY > + * KIND, either express or implied. 
See the License for the > + * specific language governing permissions and limitations > + * under the License. > + */ > +package org.apache.yay.core; > + > +import org.apache.commons.math3.linear.ArrayRealVector; > +import org.apache.commons.math3.linear.RealMatrix; > +import org.apache.commons.math3.linear.RealVector; > +import org.apache.yay.DerivativeUpdateFunction; > +import org.apache.yay.PredictionStrategy; > +import org.apache.yay.TrainingExample; > +import org.apache.yay.TrainingSet; > +import org.apache.yay.core.utils.ConversionUtils; > + > +/** > + * Default derivatives update function > + */ > +public class DefaultDerivativeUpdateFunction implements > DerivativeUpdateFunction<Double, Double> { > + > + private final PredictionStrategy<Double, Double> predictionStrategy; > + > + public DefaultDerivativeUpdateFunction(PredictionStrategy<Double, > Double> predictionStrategy) { > + this.predictionStrategy = predictionStrategy; > + } > + > + @Override > + public RealMatrix[] updateParameters(RealMatrix[] weightsMatrixSet, > TrainingSet<Double, Double> trainingExamples) { > + // set up the accumulator matrix(es) > + RealMatrix[] triangle = new RealMatrix[weightsMatrixSet.length]; > + RealVector[] deltaVectors = new RealVector[weightsMatrixSet.length]; > + > + int noOfMatrixes = weightsMatrixSet.length - 1; > + double count = 0; > + for (TrainingExample<Double, Double> trainingExample : > trainingExamples) { > + try { > + // get activations from feed forward propagation > + RealVector[] activations = > predictionStrategy.debugOutput(ConversionUtils.toValuesCollection(trainingExample.getFeatures()), > weightsMatrixSet); > + > + // calculate output error (corresponding to the last delta^l) > + RealVector nextLayerDelta = calculateOutputError(trainingExample, > activations); > + > + deltaVectors[noOfMatrixes] = nextLayerDelta; > + > + // back prop the error and update the deltas accordingly > + for (int l = noOfMatrixes; l > 0; l--) { > + RealVector 
currentActivationsVector = activations[l - 1]; > + nextLayerDelta = calculateDeltaVector(weightsMatrixSet[l], > currentActivationsVector, nextLayerDelta); > + > + // collect delta vectors for this example > + deltaVectors[l - 1] = nextLayerDelta; > + } > + > + RealVector[] newActivations = new RealVector[activations.length]; > + newActivations[0] = > ConversionUtils.toRealVector(ConversionUtils.toValuesCollection(trainingExample.getFeatures())); > + System.arraycopy(activations, 0, newActivations, 1, > activations.length - 1); > + > + // update triangle (big delta matrix) > + updateTriangle(triangle, newActivations, deltaVectors, > weightsMatrixSet); > + > + } catch (Exception e) { > + throw new RuntimeException("error during derivatives > calculation", e); > + } > + count++; > + } > + > + return createDerivatives(triangle, count); > + } > + > + private RealMatrix[] createDerivatives(RealMatrix[] triangle, double > count) { > + RealMatrix[] derivatives = new RealMatrix[triangle.length]; > + for (int i = 0; i < triangle.length; i++) { > + // TODO : introduce regularization diversification on bias term > (currently not regularized) > + derivatives[i] = triangle[i].scalarMultiply(1d / count); > + } > + return derivatives; > + } > + > + private void updateTriangle(RealMatrix[] triangle, RealVector[] > activations, RealVector[] deltaVectors, RealMatrix[] weightsMatrixSet) { > + for (int l = weightsMatrixSet.length - 1; l >= 0; l--) { > + RealMatrix realMatrix = > deltaVectors[l].outerProduct(activations[l]); > + if (triangle[l] == null) { > + triangle[l] = realMatrix; > + } else { > + triangle[l] = triangle[l].add(realMatrix); > + } > + } > + } > + > + private RealVector calculateDeltaVector(RealMatrix thetaL, RealVector > activationsVector, RealVector nextLayerDelta) { > + // TODO : remove the bias term from the error calculations > + ArrayRealVector identity = new > ArrayRealVector(activationsVector.getDimension(), 1d); > + RealVector gz = > 
activationsVector.ebeMultiply(identity.subtract(activationsVector)); // = > a^l .* (1-a^l) > + return thetaL.preMultiply(nextLayerDelta).ebeMultiply(gz); > + } > + > + private RealVector calculateOutputError(TrainingExample<Double, Double> > trainingExample, RealVector[] activations) { > + RealVector output = activations[activations.length - 1]; > + > + Double[] sampleOutput = new Double[output.getDimension()]; > + int sampleOutputIntValue = trainingExample.getOutput().intValue(); > + if (sampleOutputIntValue < sampleOutput.length) { > + sampleOutput[sampleOutputIntValue] = 1d; > + } else if (sampleOutput.length == 1) { > + sampleOutput[0] = trainingExample.getOutput(); > + } else { > + throw new RuntimeException("problem with multiclass output > mapping"); > + } > + RealVector learnedOutputRealVector = new > ArrayRealVector(sampleOutput); // turn example output to a vector > + > + // TODO : improve error calculation -> this could be er_a = out_a * > (1 - out_a) * (tgt_a - out_a) > + return output.subtract(learnedOutputRealVector); > + } > +} > > Modified: > labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java > URL: > http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java?rev=1705721&r1=1705720&r2=1705721&view=diff > > ============================================================================== > --- > labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java > (original) > +++ > labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java > Mon Sep 28 16:49:57 2015 > @@ -16,7 +16,6 @@ > * specific language governing permissions and limitations > * under the License. 
> */ > - > package org.apache.yay.core; > > import java.util.ArrayList; > > Modified: > labs/yay/trunk/core/src/main/java/org/apache/yay/core/NeuralNetworkFactory.java > URL: > http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/NeuralNetworkFactory.java?rev=1705721&r1=1705720&r2=1705721&view=diff > > ============================================================================== > --- > labs/yay/trunk/core/src/main/java/org/apache/yay/core/NeuralNetworkFactory.java > (original) > +++ > labs/yay/trunk/core/src/main/java/org/apache/yay/core/NeuralNetworkFactory.java > Mon Sep 28 16:49:57 2015 > @@ -18,9 +18,11 @@ > */ > package org.apache.yay.core; > > +import java.util.ArrayList; > import java.util.Arrays; > import java.util.Collection; > import org.apache.commons.math3.linear.RealMatrix; > +import org.apache.commons.math3.linear.RealVector; > import org.apache.yay.CreationException; > import org.apache.yay.Input; > import org.apache.yay.LearningException; > @@ -53,6 +55,12 @@ public class NeuralNetworkFactory { > final > SelectionFunction<Collection<Double>, Double> selectionFunction) throws > CreationException { > return new NeuralNetwork() { > > + @Override > + public Double[] getOutputVector(Input<Double> input) throws > PredictionException { > + Collection<Double> inputVector = > ConversionUtils.toValuesCollection(input.getFeatures()); > + return predictionStrategy.predictOutput(inputVector, > updatedRealMatrixSet); > + } > + > private RealMatrix[] updatedRealMatrixSet = realMatrixSet; > > @Override > @@ -77,8 +85,7 @@ public class NeuralNetworkFactory { > @Override > public Double predict(Input<Double> input) throws > PredictionException { > try { > - Collection<Double> inputVector = > ConversionUtils.toValuesCollection(input.getFeatures()); > - Double[] doubles = > predictionStrategy.predictOutput(inputVector, updatedRealMatrixSet); > + Double[] doubles = getOutputVector(input); > return 
selectionFunction.selectOutput(Arrays.asList(doubles)); > } catch (Exception e) { > throw new PredictionException(e); > > Modified: > labs/yay/trunk/core/src/main/java/org/apache/yay/core/utils/ExamplesFactory.java > URL: > http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/utils/ExamplesFactory.java?rev=1705721&r1=1705720&r2=1705721&view=diff > > ============================================================================== > --- > labs/yay/trunk/core/src/main/java/org/apache/yay/core/utils/ExamplesFactory.java > (original) > +++ > labs/yay/trunk/core/src/main/java/org/apache/yay/core/utils/ExamplesFactory.java > Mon Sep 28 16:49:57 2015 > @@ -19,6 +19,8 @@ > package org.apache.yay.core.utils; > > import java.util.ArrayList; > +import java.util.Collection; > + > import org.apache.yay.Feature; > import org.apache.yay.Input; > import org.apache.yay.TrainingExample; > @@ -41,6 +43,21 @@ public class ExamplesFactory { > return output; > } > }; > + } > + > + public static TrainingExample<Double, Collection<Double[]>> > createSGMExample(final Collection<Double[]> output, > + > final Double... featuresValues) { > + return new TrainingExample<Double, Collection<Double[]>>() { > + @Override > + public ArrayList<Feature<Double>> getFeatures() { > + return doublesToFeatureVector(featuresValues); > + } > + > + @Override > + public Collection<Double[]> getOutput() { > + return output; > + } > + }; > } > > public static Input<Double> createDoubleInput(final Double... 
> featuresValues) { > > Modified: > labs/yay/trunk/core/src/test/java/org/apache/yay/core/NeuralNetworkFactoryTest.java > URL: > http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/core/NeuralNetworkFactoryTest.java?rev=1705721&r1=1705720&r2=1705721&view=diff > > ============================================================================== > --- > labs/yay/trunk/core/src/test/java/org/apache/yay/core/NeuralNetworkFactoryTest.java > (original) > +++ > labs/yay/trunk/core/src/test/java/org/apache/yay/core/NeuralNetworkFactoryTest.java > Mon Sep 28 16:49:57 2015 > @@ -84,8 +84,11 @@ public class NeuralNetworkFactoryTest { > public void sampleCreationTest() throws Exception { > RealMatrix firstLayer = new Array2DRowRealMatrix(new double[][]{{1d, > 1d, 2d, 3d}, {1d, 1d, 2d, 3d}, {1d, 1d, 2d, 3d}}); > RealMatrix secondLayer = new Array2DRowRealMatrix(new double[][]{{1d, > 2d, 3d}}); > + > RealMatrix[] RealMatrixes = new RealMatrix[]{firstLayer, secondLayer}; > + > NeuralNetwork neuralNetwork = createFFNN(RealMatrixes); > + > Double prdictedValue = neuralNetwork.predict(createSample(5d, 6d, > 7d)); > assertEquals(1l, Math.round(prdictedValue)); > assertEquals(Double.valueOf(0.9975273768433653d), prdictedValue); > > > > --------------------------------------------------------------------- > To unsubscribe, e-mail: commits-unsubscr...@labs.apache.org > For additional commands, e-mail: commits-h...@labs.apache.org > >