Hi, send an email to
labs-unsubscr...@labs.apache.org commits-unsubscr...@labs.apache.org and follow the instructions. rgds jan i On Monday, 28 September 2015, brandie serignet <bserig...@gmail.com> wrote: > I need help. I want out > On Sep 28, 2015 11:50 AM, <tomm...@apache.org <javascript:;>> wrote: > > > Author: tommaso > > Date: Mon Sep 28 16:49:57 2015 > > New Revision: 1705721 > > > > URL: http://svn.apache.org/viewvc?rev=1705721&view=rev > > Log: > > switch from batch to stochastic GD in backprop, abstracted derivative > > update function > > > > Added: > > > > > labs/yay/trunk/api/src/main/java/org/apache/yay/DerivativeUpdateFunction.java > > > > > labs/yay/trunk/core/src/main/java/org/apache/yay/core/DefaultDerivativeUpdateFunction.java > > Removed: > > > labs/yay/trunk/api/src/main/java/org/apache/yay/HypothesisFactory.java > > Modified: > > labs/yay/trunk/api/src/main/java/org/apache/yay/NeuralNetwork.java > > > > > labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java > > > > > labs/yay/trunk/core/src/main/java/org/apache/yay/core/BasicPerceptron.java > > > > > labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java > > > > > labs/yay/trunk/core/src/main/java/org/apache/yay/core/NeuralNetworkFactory.java > > > > > labs/yay/trunk/core/src/main/java/org/apache/yay/core/utils/ExamplesFactory.java > > > > > labs/yay/trunk/core/src/test/java/org/apache/yay/core/NeuralNetworkFactoryTest.java > > > > Added: > > > labs/yay/trunk/api/src/main/java/org/apache/yay/DerivativeUpdateFunction.java > > URL: > > > http://svn.apache.org/viewvc/labs/yay/trunk/api/src/main/java/org/apache/yay/DerivativeUpdateFunction.java?rev=1705721&view=auto > > > > > ============================================================================== > > --- > > > labs/yay/trunk/api/src/main/java/org/apache/yay/DerivativeUpdateFunction.java > > (added) > > +++ > > > labs/yay/trunk/api/src/main/java/org/apache/yay/DerivativeUpdateFunction.java > > Mon Sep 
28 16:49:57 2015 > > @@ -0,0 +1,30 @@ > > +/* > > + * Licensed to the Apache Software Foundation (ASF) under one > > + * or more contributor license agreements. See the NOTICE file > > + * distributed with this work for additional information > > + * regarding copyright ownership. The ASF licenses this file > > + * to you under the Apache License, Version 2.0 (the > > + * "License"); you may not use this file except in compliance > > + * with the License. You may obtain a copy of the License at > > + * > > + * http://www.apache.org/licenses/LICENSE-2.0 > > + * > > + * Unless required by applicable law or agreed to in writing, > > + * software distributed under the License is distributed on an > > + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY > > + * KIND, either express or implied. See the License for the > > + * specific language governing permissions and limitations > > + * under the License. > > + */ > > +package org.apache.yay; > > + > > +import org.apache.commons.math3.linear.RealMatrix; > > +import org.apache.yay.TrainingSet; > > + > > +/** > > + * Derivatives update function > > + */ > > +public interface DerivativeUpdateFunction<F,O> { > > + > > + RealMatrix[] updateParameters(RealMatrix[] weights, TrainingSet<F,O> > > trainingSet); > > +} > > > > Modified: > > labs/yay/trunk/api/src/main/java/org/apache/yay/NeuralNetwork.java > > URL: > > > http://svn.apache.org/viewvc/labs/yay/trunk/api/src/main/java/org/apache/yay/NeuralNetwork.java?rev=1705721&r1=1705720&r2=1705721&view=diff > > > > > ============================================================================== > > --- labs/yay/trunk/api/src/main/java/org/apache/yay/NeuralNetwork.java > > (original) > > +++ labs/yay/trunk/api/src/main/java/org/apache/yay/NeuralNetwork.java > Mon > > Sep 28 16:49:57 2015 > > @@ -21,8 +21,19 @@ package org.apache.yay; > > import org.apache.commons.math3.linear.RealMatrix; > > > > /** > > - * A neural network is a layered connected graph of elaboration units 
> > + * A Neural Network is a layered connected graph of elaboration units. > > + * > > + * It takes a double vector as input and produces a double vector as > > output. > > */ > > public interface NeuralNetwork extends Hypothesis<RealMatrix, Double, > > Double> { > > > > + /** > > + * Predict the output for a given input > > + * > > + * @param input the input to evaluate > > + * @return the predicted output > > + * @throws PredictionException if any error occurs during the > > prediction phase > > + */ > > + Double[] getOutputVector(Input<Double> input) throws > > PredictionException; > > + > > } > > > > Modified: > > > labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java > > URL: > > > http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java?rev=1705721&r1=1705720&r2=1705721&view=diff > > > > > ============================================================================== > > --- > > > labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java > > (original) > > +++ > > > labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java > > Mon Sep 28 16:49:57 2015 > > @@ -19,6 +19,7 @@ > > package org.apache.yay.core; > > > > import java.util.Arrays; > > +import java.util.DoubleSummaryStatistics; > > import java.util.Iterator; > > > > import org.apache.commons.math3.linear.Array2DRowRealMatrix; > > @@ -26,6 +27,7 @@ import org.apache.commons.math3.linear.A > > import org.apache.commons.math3.linear.RealMatrix; > > import org.apache.commons.math3.linear.RealVector; > > import org.apache.yay.CostFunction; > > +import org.apache.yay.DerivativeUpdateFunction; > > import org.apache.yay.LearningStrategy; > > import org.apache.yay.NeuralNetwork; > > import org.apache.yay.PredictionStrategy; > > @@ -46,6 +48,7 @@ public class BackPropagationLearningStra > > > > private final PredictionStrategy<Double, Double> 
predictionStrategy; > > private final CostFunction<RealMatrix, Double, Double> costFunction; > > + private final DerivativeUpdateFunction<Double, Double> > > derivativeUpdateFunction; > > private final double alpha; > > private final double threshold; > > private final int batch; > > @@ -63,6 +66,7 @@ public class BackPropagationLearningStra > > this.alpha = alpha; > > this.threshold = threshold; > > this.batch = batch; > > + this.derivativeUpdateFunction = new > > DefaultDerivativeUpdateFunction(predictionStrategy); > > } > > > > public BackPropagationLearningStrategy() { > > @@ -72,6 +76,7 @@ public class BackPropagationLearningStra > > this.alpha = DEFAULT_ALPHA; > > this.threshold = DEFAULT_THRESHOLD; > > this.batch = 1; > > + this.derivativeUpdateFunction = new > > DefaultDerivativeUpdateFunction(predictionStrategy); > > } > > > > @Override > > @@ -114,7 +119,7 @@ public class BackPropagationLearningStra > > cost = newCost; > > > > // calculate the derivatives to update the parameters > > - RealMatrix[] derivatives = > calculateDerivatives(weightsMatrixSet, > > samples); > > + RealMatrix[] derivatives = > > derivativeUpdateFunction.updateParameters(weightsMatrixSet, samples); > > > > // calculate the updated parameters > > updatedWeights = updateWeights(weightsMatrixSet, derivatives, > > alpha); > > @@ -131,48 +136,6 @@ public class BackPropagationLearningStra > > return updatedWeights; > > } > > > > - private RealMatrix[] calculateDerivatives(RealMatrix[] > > weightsMatrixSet, TrainingSet<Double, Double> trainingExamples) throws > > WeightLearningException { > > - // set up the accumulator matrix(es) > > - RealMatrix[] triangle = new RealMatrix[weightsMatrixSet.length]; > > - RealVector[] deltaVectors = new RealVector[weightsMatrixSet.length]; > > - > > - int noOfMatrixes = weightsMatrixSet.length - 1; > > - double count = 0; > > - for (TrainingExample<Double, Double> trainingExample : > > trainingExamples) { > > - try { > > - // get activations from feed 
forward propagation > > - RealVector[] activations = > > > predictionStrategy.debugOutput(ConversionUtils.toValuesCollection(trainingExample.getFeatures()), > > weightsMatrixSet); > > - > > - // calculate output error (corresponding to the last delta^l) > > - RealVector nextLayerDelta = > calculateOutputError(trainingExample, > > activations); > > - > > - deltaVectors[noOfMatrixes] = nextLayerDelta; > > - > > - // back prop the error and update the deltas accordingly > > - for (int l = noOfMatrixes; l > 0; l--) { > > - RealVector currentActivationsVector = activations[l - 1]; > > - nextLayerDelta = calculateDeltaVector(weightsMatrixSet[l], > > currentActivationsVector, nextLayerDelta); > > - > > - // collect delta vectors for this example > > - deltaVectors[l - 1] = nextLayerDelta; > > - } > > - > > - RealVector[] newActivations = new > RealVector[activations.length]; > > - newActivations[0] = > > > ConversionUtils.toRealVector(ConversionUtils.toValuesCollection(trainingExample.getFeatures())); > > - System.arraycopy(activations, 0, newActivations, 1, > > activations.length - 1); > > - > > - // update triangle (big delta matrix) > > - updateTriangle(triangle, newActivations, deltaVectors, > > weightsMatrixSet); > > - > > - } catch (Exception e) { > > - throw new WeightLearningException("error during derivatives > > calculation", e); > > - } > > - count++; > > - } > > - > > - return createDerivatives(triangle, count); > > - } > > - > > private RealMatrix[] updateWeights(RealMatrix[] weightsMatrixSet, > > RealMatrix[] derivatives, double alpha) { > > RealMatrix[] updatedParameters = new > > RealMatrix[weightsMatrixSet.length]; > > for (int l = 0; l < weightsMatrixSet.length; l++) { > > @@ -187,48 +150,4 @@ public class BackPropagationLearningStra > > return updatedParameters; > > } > > > > - private RealMatrix[] createDerivatives(RealMatrix[] triangle, double > > count) { > > - RealMatrix[] derivatives = new RealMatrix[triangle.length]; > > - for (int i = 0; i < 
triangle.length; i++) { > > - // TODO : introduce regularization diversification on bias term > > (currently not regularized) > > - derivatives[i] = triangle[i].scalarMultiply(1d / count); > > - } > > - return derivatives; > > - } > > - > > - private void updateTriangle(RealMatrix[] triangle, RealVector[] > > activations, RealVector[] deltaVectors, RealMatrix[] weightsMatrixSet) { > > - for (int l = weightsMatrixSet.length - 1; l >= 0; l--) { > > - RealMatrix realMatrix = > > deltaVectors[l].outerProduct(activations[l]); > > - if (triangle[l] == null) { > > - triangle[l] = realMatrix; > > - } else { > > - triangle[l] = triangle[l].add(realMatrix); > > - } > > - } > > - } > > - > > - private RealVector calculateDeltaVector(RealMatrix thetaL, RealVector > > activationsVector, RealVector nextLayerDelta) { > > - // TODO : remove the bias term from the error calculations > > - ArrayRealVector identity = new > > ArrayRealVector(activationsVector.getDimension(), 1d); > > - RealVector gz = > > activationsVector.ebeMultiply(identity.subtract(activationsVector)); // = > > a^l .* (1-a^l) > > - return thetaL.preMultiply(nextLayerDelta).ebeMultiply(gz); > > - } > > - > > - private RealVector calculateOutputError(TrainingExample<Double, > Double> > > trainingExample, RealVector[] activations) { > > - RealVector output = activations[activations.length - 1]; > > - > > - Double[] sampleOutput = new Double[output.getDimension()]; > > - int sampleOutputIntValue = trainingExample.getOutput().intValue(); > > - if (sampleOutputIntValue < sampleOutput.length) { > > - sampleOutput[sampleOutputIntValue] = 1d; > > - } else if (sampleOutput.length == 1) { > > - sampleOutput[0] = trainingExample.getOutput(); > > - } else { > > - throw new RuntimeException("problem with multiclass output > > mapping"); > > - } > > - RealVector learnedOutputRealVector = new > > ArrayRealVector(sampleOutput); // turn example output to a vector > > - > > - // TODO : improve error calculation -> this could be er_a 
= out_a * > > (1 - out_a) * (tgt_a - out_a) > > - return output.subtract(learnedOutputRealVector); > > - } > > } > > > > Modified: > > > labs/yay/trunk/core/src/main/java/org/apache/yay/core/BasicPerceptron.java > > URL: > > > http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/BasicPerceptron.java?rev=1705721&r1=1705720&r2=1705721&view=diff > > > > > ============================================================================== > > --- > > > labs/yay/trunk/core/src/main/java/org/apache/yay/core/BasicPerceptron.java > > (original) > > +++ > > > labs/yay/trunk/core/src/main/java/org/apache/yay/core/BasicPerceptron.java > > Mon Sep 28 16:49:57 2015 > > @@ -94,4 +94,13 @@ public class BasicPerceptron implements > > return > > > perceptronNeuron.elaborate(ConversionUtils.toValuesCollection(input.getFeatures()).toArray( > > new Double[input.getFeatures().size()])); > > } > > + > > + @Override > > + public Double[] getOutputVector(Input<Double> input) throws > > PredictionException { > > + Double elaborate = > > > perceptronNeuron.elaborate(ConversionUtils.toValuesCollection(input.getFeatures()).toArray( > > + new Double[input.getFeatures().size()])); > > + Double[] ar = new Double[1]; > > + ar[0] = elaborate; > > + return ar; > > + } > > } > > > > Added: > > > labs/yay/trunk/core/src/main/java/org/apache/yay/core/DefaultDerivativeUpdateFunction.java > > URL: > > > http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/DefaultDerivativeUpdateFunction.java?rev=1705721&view=auto > > > > > ============================================================================== > > --- > > > labs/yay/trunk/core/src/main/java/org/apache/yay/core/DefaultDerivativeUpdateFunction.java > > (added) > > +++ > > > labs/yay/trunk/core/src/main/java/org/apache/yay/core/DefaultDerivativeUpdateFunction.java > > Mon Sep 28 16:49:57 2015 > > @@ -0,0 +1,128 @@ > > +/* > > + * Licensed to the Apache Software Foundation (ASF) under one > > + * 
or more contributor license agreements. See the NOTICE file > > + * distributed with this work for additional information > > + * regarding copyright ownership. The ASF licenses this file > > + * to you under the Apache License, Version 2.0 (the > > + * "License"); you may not use this file except in compliance > > + * with the License. You may obtain a copy of the License at > > + * > > + * http://www.apache.org/licenses/LICENSE-2.0 > > + * > > + * Unless required by applicable law or agreed to in writing, > > + * software distributed under the License is distributed on an > > + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY > > + * KIND, either express or implied. See the License for the > > + * specific language governing permissions and limitations > > + * under the License. > > + */ > > +package org.apache.yay.core; > > + > > +import org.apache.commons.math3.linear.ArrayRealVector; > > +import org.apache.commons.math3.linear.RealMatrix; > > +import org.apache.commons.math3.linear.RealVector; > > +import org.apache.yay.DerivativeUpdateFunction; > > +import org.apache.yay.PredictionStrategy; > > +import org.apache.yay.TrainingExample; > > +import org.apache.yay.TrainingSet; > > +import org.apache.yay.core.utils.ConversionUtils; > > + > > +/** > > + * Default derivatives update function > > + */ > > +public class DefaultDerivativeUpdateFunction implements > > DerivativeUpdateFunction<Double, Double> { > > + > > + private final PredictionStrategy<Double, Double> predictionStrategy; > > + > > + public DefaultDerivativeUpdateFunction(PredictionStrategy<Double, > > Double> predictionStrategy) { > > + this.predictionStrategy = predictionStrategy; > > + } > > + > > + @Override > > + public RealMatrix[] updateParameters(RealMatrix[] weightsMatrixSet, > > TrainingSet<Double, Double> trainingExamples) { > > + // set up the accumulator matrix(es) > > + RealMatrix[] triangle = new RealMatrix[weightsMatrixSet.length]; > > + RealVector[] deltaVectors = new 
RealVector[weightsMatrixSet.length]; > > + > > + int noOfMatrixes = weightsMatrixSet.length - 1; > > + double count = 0; > > + for (TrainingExample<Double, Double> trainingExample : > > trainingExamples) { > > + try { > > + // get activations from feed forward propagation > > + RealVector[] activations = > > > predictionStrategy.debugOutput(ConversionUtils.toValuesCollection(trainingExample.getFeatures()), > > weightsMatrixSet); > > + > > + // calculate output error (corresponding to the last delta^l) > > + RealVector nextLayerDelta = > calculateOutputError(trainingExample, > > activations); > > + > > + deltaVectors[noOfMatrixes] = nextLayerDelta; > > + > > + // back prop the error and update the deltas accordingly > > + for (int l = noOfMatrixes; l > 0; l--) { > > + RealVector currentActivationsVector = activations[l - 1]; > > + nextLayerDelta = calculateDeltaVector(weightsMatrixSet[l], > > currentActivationsVector, nextLayerDelta); > > + > > + // collect delta vectors for this example > > + deltaVectors[l - 1] = nextLayerDelta; > > + } > > + > > + RealVector[] newActivations = new > RealVector[activations.length]; > > + newActivations[0] = > > > ConversionUtils.toRealVector(ConversionUtils.toValuesCollection(trainingExample.getFeatures())); > > + System.arraycopy(activations, 0, newActivations, 1, > > activations.length - 1); > > + > > + // update triangle (big delta matrix) > > + updateTriangle(triangle, newActivations, deltaVectors, > > weightsMatrixSet); > > + > > + } catch (Exception e) { > > + throw new RuntimeException("error during derivatives > > calculation", e); > > + } > > + count++; > > + } > > + > > + return createDerivatives(triangle, count); > > + } > > + > > + private RealMatrix[] createDerivatives(RealMatrix[] triangle, double > > count) { > > + RealMatrix[] derivatives = new RealMatrix[triangle.length]; > > + for (int i = 0; i < triangle.length; i++) { > > + // TODO : introduce regularization diversification on bias term > > (currently not 
regularized) > > + derivatives[i] = triangle[i].scalarMultiply(1d / count); > > + } > > + return derivatives; > > + } > > + > > + private void updateTriangle(RealMatrix[] triangle, RealVector[] > > activations, RealVector[] deltaVectors, RealMatrix[] weightsMatrixSet) { > > + for (int l = weightsMatrixSet.length - 1; l >= 0; l--) { > > + RealMatrix realMatrix = > > deltaVectors[l].outerProduct(activations[l]); > > + if (triangle[l] == null) { > > + triangle[l] = realMatrix; > > + } else { > > + triangle[l] = triangle[l].add(realMatrix); > > + } > > + } > > + } > > + > > + private RealVector calculateDeltaVector(RealMatrix thetaL, RealVector > > activationsVector, RealVector nextLayerDelta) { > > + // TODO : remove the bias term from the error calculations > > + ArrayRealVector identity = new > > ArrayRealVector(activationsVector.getDimension(), 1d); > > + RealVector gz = > > activationsVector.ebeMultiply(identity.subtract(activationsVector)); // = > > a^l .* (1-a^l) > > + return thetaL.preMultiply(nextLayerDelta).ebeMultiply(gz); > > + } > > + > > + private RealVector calculateOutputError(TrainingExample<Double, > Double> > > trainingExample, RealVector[] activations) { > > + RealVector output = activations[activations.length - 1]; > > + > > + Double[] sampleOutput = new Double[output.getDimension()]; > > + int sampleOutputIntValue = trainingExample.getOutput().intValue(); > > + if (sampleOutputIntValue < sampleOutput.length) { > > + sampleOutput[sampleOutputIntValue] = 1d; > > + } else if (sampleOutput.length == 1) { > > + sampleOutput[0] = trainingExample.getOutput(); > > + } else { > > + throw new RuntimeException("problem with multiclass output > > mapping"); > > + } > > + RealVector learnedOutputRealVector = new > > ArrayRealVector(sampleOutput); // turn example output to a vector > > + > > + // TODO : improve error calculation -> this could be er_a = out_a * > > (1 - out_a) * (tgt_a - out_a) > > + return output.subtract(learnedOutputRealVector); > > + } > > 
+} > > > > Modified: > > > labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java > > URL: > > > http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java?rev=1705721&r1=1705720&r2=1705721&view=diff > > > > > ============================================================================== > > --- > > > labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java > > (original) > > +++ > > > labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java > > Mon Sep 28 16:49:57 2015 > > @@ -16,7 +16,6 @@ > > * specific language governing permissions and limitations > > * under the License. > > */ > > - > > package org.apache.yay.core; > > > > import java.util.ArrayList; > > > > Modified: > > > labs/yay/trunk/core/src/main/java/org/apache/yay/core/NeuralNetworkFactory.java > > URL: > > > http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/NeuralNetworkFactory.java?rev=1705721&r1=1705720&r2=1705721&view=diff > > > > > ============================================================================== > > --- > > > labs/yay/trunk/core/src/main/java/org/apache/yay/core/NeuralNetworkFactory.java > > (original) > > +++ > > > labs/yay/trunk/core/src/main/java/org/apache/yay/core/NeuralNetworkFactory.java > > Mon Sep 28 16:49:57 2015 > > @@ -18,9 +18,11 @@ > > */ > > package org.apache.yay.core; > > > > +import java.util.ArrayList; > > import java.util.Arrays; > > import java.util.Collection; > > import org.apache.commons.math3.linear.RealMatrix; > > +import org.apache.commons.math3.linear.RealVector; > > import org.apache.yay.CreationException; > > import org.apache.yay.Input; > > import org.apache.yay.LearningException; > > @@ -53,6 +55,12 @@ public class NeuralNetworkFactory { > > final > > SelectionFunction<Collection<Double>, Double> selectionFunction) throws > > CreationException { > > return new NeuralNetwork() { > > > > + @Override > > + 
public Double[] getOutputVector(Input<Double> input) throws > > PredictionException { > > + Collection<Double> inputVector = > > ConversionUtils.toValuesCollection(input.getFeatures()); > > + return predictionStrategy.predictOutput(inputVector, > > updatedRealMatrixSet); > > + } > > + > > private RealMatrix[] updatedRealMatrixSet = realMatrixSet; > > > > @Override > > @@ -77,8 +85,7 @@ public class NeuralNetworkFactory { > > @Override > > public Double predict(Input<Double> input) throws > > PredictionException { > > try { > > - Collection<Double> inputVector = > > ConversionUtils.toValuesCollection(input.getFeatures()); > > - Double[] doubles = > > predictionStrategy.predictOutput(inputVector, updatedRealMatrixSet); > > + Double[] doubles = getOutputVector(input); > > return selectionFunction.selectOutput(Arrays.asList(doubles)); > > } catch (Exception e) { > > throw new PredictionException(e); > > > > Modified: > > > labs/yay/trunk/core/src/main/java/org/apache/yay/core/utils/ExamplesFactory.java > > URL: > > > http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/utils/ExamplesFactory.java?rev=1705721&r1=1705720&r2=1705721&view=diff > > > > > ============================================================================== > > --- > > > labs/yay/trunk/core/src/main/java/org/apache/yay/core/utils/ExamplesFactory.java > > (original) > > +++ > > > labs/yay/trunk/core/src/main/java/org/apache/yay/core/utils/ExamplesFactory.java > > Mon Sep 28 16:49:57 2015 > > @@ -19,6 +19,8 @@ > > package org.apache.yay.core.utils; > > > > import java.util.ArrayList; > > +import java.util.Collection; > > + > > import org.apache.yay.Feature; > > import org.apache.yay.Input; > > import org.apache.yay.TrainingExample; > > @@ -41,6 +43,21 @@ public class ExamplesFactory { > > return output; > > } > > }; > > + } > > + > > + public static TrainingExample<Double, Collection<Double[]>> > > createSGMExample(final Collection<Double[]> output, > > + > > final 
Double... featuresValues) { > > + return new TrainingExample<Double, Collection<Double[]>>() { > > + @Override > > + public ArrayList<Feature<Double>> getFeatures() { > > + return doublesToFeatureVector(featuresValues); > > + } > > + > > + @Override > > + public Collection<Double[]> getOutput() { > > + return output; > > + } > > + }; > > } > > > > public static Input<Double> createDoubleInput(final Double... > > featuresValues) { > > > > Modified: > > > labs/yay/trunk/core/src/test/java/org/apache/yay/core/NeuralNetworkFactoryTest.java > > URL: > > > http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/core/NeuralNetworkFactoryTest.java?rev=1705721&r1=1705720&r2=1705721&view=diff > > > > > ============================================================================== > > --- > > > labs/yay/trunk/core/src/test/java/org/apache/yay/core/NeuralNetworkFactoryTest.java > > (original) > > +++ > > > labs/yay/trunk/core/src/test/java/org/apache/yay/core/NeuralNetworkFactoryTest.java > > Mon Sep 28 16:49:57 2015 > > @@ -84,8 +84,11 @@ public class NeuralNetworkFactoryTest { > > public void sampleCreationTest() throws Exception { > > RealMatrix firstLayer = new Array2DRowRealMatrix(new double[][]{{1d, > > 1d, 2d, 3d}, {1d, 1d, 2d, 3d}, {1d, 1d, 2d, 3d}}); > > RealMatrix secondLayer = new Array2DRowRealMatrix(new > double[][]{{1d, > > 2d, 3d}}); > > + > > RealMatrix[] RealMatrixes = new RealMatrix[]{firstLayer, > secondLayer}; > > + > > NeuralNetwork neuralNetwork = createFFNN(RealMatrixes); > > + > > Double prdictedValue = neuralNetwork.predict(createSample(5d, 6d, > > 7d)); > > assertEquals(1l, Math.round(prdictedValue)); > > assertEquals(Double.valueOf(0.9975273768433653d), prdictedValue); > > > > > > > > --------------------------------------------------------------------- > > To unsubscribe, e-mail: commits-unsubscr...@labs.apache.org > <javascript:;> > > For additional commands, e-mail: commits-h...@labs.apache.org > <javascript:;> > > > > > -- 
Sent from My iPad, sorry for any misspellings.