Hi

To unsubscribe, send an email to

labs-unsubscr...@labs.apache.org
commits-unsubscr...@labs.apache.org

and follow the instructions.

rgds
jan i

On Monday, 28 September 2015, brandie serignet <bserig...@gmail.com> wrote:

> I need help. I want out
> On Sep 28, 2015 11:50 AM, <tomm...@apache.org> wrote:
>
> > Author: tommaso
> > Date: Mon Sep 28 16:49:57 2015
> > New Revision: 1705721
> >
> > URL: http://svn.apache.org/viewvc?rev=1705721&view=rev
> > Log:
> > switch from batch to stochastic GD in backprop, abstracted derivative update function
> >
> > Added:
> >     labs/yay/trunk/api/src/main/java/org/apache/yay/DerivativeUpdateFunction.java
> >     labs/yay/trunk/core/src/main/java/org/apache/yay/core/DefaultDerivativeUpdateFunction.java
> > Removed:
> >     labs/yay/trunk/api/src/main/java/org/apache/yay/HypothesisFactory.java
> > Modified:
> >     labs/yay/trunk/api/src/main/java/org/apache/yay/NeuralNetwork.java
> >     labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java
> >     labs/yay/trunk/core/src/main/java/org/apache/yay/core/BasicPerceptron.java
> >     labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java
> >     labs/yay/trunk/core/src/main/java/org/apache/yay/core/NeuralNetworkFactory.java
> >     labs/yay/trunk/core/src/main/java/org/apache/yay/core/utils/ExamplesFactory.java
> >     labs/yay/trunk/core/src/test/java/org/apache/yay/core/NeuralNetworkFactoryTest.java
> >
> > Added: labs/yay/trunk/api/src/main/java/org/apache/yay/DerivativeUpdateFunction.java
> > URL: http://svn.apache.org/viewvc/labs/yay/trunk/api/src/main/java/org/apache/yay/DerivativeUpdateFunction.java?rev=1705721&view=auto
> > ==============================================================================
> > --- labs/yay/trunk/api/src/main/java/org/apache/yay/DerivativeUpdateFunction.java (added)
> > +++ labs/yay/trunk/api/src/main/java/org/apache/yay/DerivativeUpdateFunction.java Mon Sep 28 16:49:57 2015
> > @@ -0,0 +1,30 @@
> > +/*
> > + * Licensed to the Apache Software Foundation (ASF) under one
> > + * or more contributor license agreements.  See the NOTICE file
> > + * distributed with this work for additional information
> > + * regarding copyright ownership.  The ASF licenses this file
> > + * to you under the Apache License, Version 2.0 (the
> > + * "License"); you may not use this file except in compliance
> > + * with the License.  You may obtain a copy of the License at
> > + *
> > + *  http://www.apache.org/licenses/LICENSE-2.0
> > + *
> > + * Unless required by applicable law or agreed to in writing,
> > + * software distributed under the License is distributed on an
> > + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
> > + * KIND, either express or implied.  See the License for the
> > + * specific language governing permissions and limitations
> > + * under the License.
> > + */
> > +package org.apache.yay;
> > +
> > +import org.apache.commons.math3.linear.RealMatrix;
> > +import org.apache.yay.TrainingSet;
> > +
> > +/**
> > + * Derivatives update function
> > + */
> > +public interface DerivativeUpdateFunction<F,O> {
> > +
> > +  RealMatrix[] updateParameters(RealMatrix[] weights, TrainingSet<F,O> trainingSet);
> > +}
> >
> > Modified: labs/yay/trunk/api/src/main/java/org/apache/yay/NeuralNetwork.java
> > URL: http://svn.apache.org/viewvc/labs/yay/trunk/api/src/main/java/org/apache/yay/NeuralNetwork.java?rev=1705721&r1=1705720&r2=1705721&view=diff
> > ==============================================================================
> > --- labs/yay/trunk/api/src/main/java/org/apache/yay/NeuralNetwork.java (original)
> > +++ labs/yay/trunk/api/src/main/java/org/apache/yay/NeuralNetwork.java Mon Sep 28 16:49:57 2015
> > @@ -21,8 +21,19 @@ package org.apache.yay;
> >  import org.apache.commons.math3.linear.RealMatrix;
> >
> >  /**
> > - * A neural network is a layered connected graph of elaboration units
> > + * A Neural Network is a layered connected graph of elaboration units.
> > + *
> > + * It takes a double vector as input and produces a double vector as output.
> >   */
> >  public interface NeuralNetwork extends Hypothesis<RealMatrix, Double, Double> {
> >
> > +  /**
> > +   * Predict the output for a given input
> > +   *
> > +   * @param input the input to evaluate
> > +   * @return the predicted output
> > +   * @throws PredictionException if any error occurs during the prediction phase
> > +   */
> > +  Double[] getOutputVector(Input<Double> input) throws PredictionException;
> > +
> >  }
> >
> > Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java
> > URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java?rev=1705721&r1=1705720&r2=1705721&view=diff
> > ==============================================================================
> > --- labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java (original)
> > +++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java Mon Sep 28 16:49:57 2015
> > @@ -19,6 +19,7 @@
> >  package org.apache.yay.core;
> >
> >  import java.util.Arrays;
> > +import java.util.DoubleSummaryStatistics;
> >  import java.util.Iterator;
> >
> >  import org.apache.commons.math3.linear.Array2DRowRealMatrix;
> > @@ -26,6 +27,7 @@ import org.apache.commons.math3.linear.A
> >  import org.apache.commons.math3.linear.RealMatrix;
> >  import org.apache.commons.math3.linear.RealVector;
> >  import org.apache.yay.CostFunction;
> > +import org.apache.yay.DerivativeUpdateFunction;
> >  import org.apache.yay.LearningStrategy;
> >  import org.apache.yay.NeuralNetwork;
> >  import org.apache.yay.PredictionStrategy;
> > @@ -46,6 +48,7 @@ public class BackPropagationLearningStra
> >
> >    private final PredictionStrategy<Double, Double> predictionStrategy;
> >    private final CostFunction<RealMatrix, Double, Double> costFunction;
> > +  private final DerivativeUpdateFunction<Double, Double> derivativeUpdateFunction;
> >    private final double alpha;
> >    private final double threshold;
> >    private final int batch;
> > @@ -63,6 +66,7 @@ public class BackPropagationLearningStra
> >      this.alpha = alpha;
> >      this.threshold = threshold;
> >      this.batch = batch;
> > +    this.derivativeUpdateFunction = new DefaultDerivativeUpdateFunction(predictionStrategy);
> >    }
> >
> >    public BackPropagationLearningStrategy() {
> > @@ -72,6 +76,7 @@ public class BackPropagationLearningStra
> >      this.alpha = DEFAULT_ALPHA;
> >      this.threshold = DEFAULT_THRESHOLD;
> >      this.batch = 1;
> > +    this.derivativeUpdateFunction = new DefaultDerivativeUpdateFunction(predictionStrategy);
> >    }
> >
> >    @Override
> > @@ -114,7 +119,7 @@ public class BackPropagationLearningStra
> >          cost = newCost;
> >
> >          // calculate the derivatives to update the parameters
> > -        RealMatrix[] derivatives = calculateDerivatives(weightsMatrixSet, samples);
> > +        RealMatrix[] derivatives = derivativeUpdateFunction.updateParameters(weightsMatrixSet, samples);
> >
> >          // calculate the updated parameters
> >          updatedWeights = updateWeights(weightsMatrixSet, derivatives, alpha);
> > @@ -131,48 +136,6 @@ public class BackPropagationLearningStra
> >      return updatedWeights;
> >    }
> >
> > -  private RealMatrix[] calculateDerivatives(RealMatrix[] weightsMatrixSet, TrainingSet<Double, Double> trainingExamples) throws WeightLearningException {
> > -    // set up the accumulator matrix(es)
> > -    RealMatrix[] triangle = new RealMatrix[weightsMatrixSet.length];
> > -    RealVector[] deltaVectors = new RealVector[weightsMatrixSet.length];
> > -
> > -    int noOfMatrixes = weightsMatrixSet.length - 1;
> > -    double count = 0;
> > -    for (TrainingExample<Double, Double> trainingExample : trainingExamples) {
> > -      try {
> > -        // get activations from feed forward propagation
> > -        RealVector[] activations = predictionStrategy.debugOutput(ConversionUtils.toValuesCollection(trainingExample.getFeatures()), weightsMatrixSet);
> > -
> > -        // calculate output error (corresponding to the last delta^l)
> > -        RealVector nextLayerDelta = calculateOutputError(trainingExample, activations);
> > -
> > -        deltaVectors[noOfMatrixes] = nextLayerDelta;
> > -
> > -        // back prop the error and update the deltas accordingly
> > -        for (int l = noOfMatrixes; l > 0; l--) {
> > -          RealVector currentActivationsVector = activations[l - 1];
> > -          nextLayerDelta = calculateDeltaVector(weightsMatrixSet[l], currentActivationsVector, nextLayerDelta);
> > -
> > -          // collect delta vectors for this example
> > -          deltaVectors[l - 1] = nextLayerDelta;
> > -        }
> > -
> > -        RealVector[] newActivations = new RealVector[activations.length];
> > -        newActivations[0] = ConversionUtils.toRealVector(ConversionUtils.toValuesCollection(trainingExample.getFeatures()));
> > -        System.arraycopy(activations, 0, newActivations, 1, activations.length - 1);
> > -
> > -        // update triangle (big delta matrix)
> > -        updateTriangle(triangle, newActivations, deltaVectors, weightsMatrixSet);
> > -
> > -      } catch (Exception e) {
> > -        throw new WeightLearningException("error during derivatives calculation", e);
> > -      count++;
> > -    }
> > -
> > -    return createDerivatives(triangle, count);
> > -  }
> > -
> >    private RealMatrix[] updateWeights(RealMatrix[] weightsMatrixSet, RealMatrix[] derivatives, double alpha) {
> >      RealMatrix[] updatedParameters = new RealMatrix[weightsMatrixSet.length];
> >      for (int l = 0; l < weightsMatrixSet.length; l++) {
> > @@ -187,48 +150,4 @@ public class BackPropagationLearningStra
> >      return updatedParameters;
> >    }
> >
> > -  private RealMatrix[] createDerivatives(RealMatrix[] triangle, double count) {
> > -    RealMatrix[] derivatives = new RealMatrix[triangle.length];
> > -    for (int i = 0; i < triangle.length; i++) {
> > -      // TODO : introduce regularization diversification on bias term (currently not regularized)
> > -      derivatives[i] = triangle[i].scalarMultiply(1d / count);
> > -    }
> > -    return derivatives;
> > -  }
> > -
> > -  private void updateTriangle(RealMatrix[] triangle, RealVector[] activations, RealVector[] deltaVectors, RealMatrix[] weightsMatrixSet) {
> > -    for (int l = weightsMatrixSet.length - 1; l >= 0; l--) {
> > -      RealMatrix realMatrix = deltaVectors[l].outerProduct(activations[l]);
> > -      if (triangle[l] == null) {
> > -        triangle[l] = realMatrix;
> > -      } else {
> > -        triangle[l] = triangle[l].add(realMatrix);
> > -      }
> > -    }
> > -  }
> > -
> > -  private RealVector calculateDeltaVector(RealMatrix thetaL, RealVector activationsVector, RealVector nextLayerDelta) {
> > -    // TODO : remove the bias term from the error calculations
> > -    ArrayRealVector identity = new ArrayRealVector(activationsVector.getDimension(), 1d);
> > -    RealVector gz = activationsVector.ebeMultiply(identity.subtract(activationsVector)); // = a^l .* (1-a^l)
> > -    return thetaL.preMultiply(nextLayerDelta).ebeMultiply(gz);
> > -  }
> > -
> > -  private RealVector calculateOutputError(TrainingExample<Double, Double> trainingExample, RealVector[] activations) {
> > -    RealVector output = activations[activations.length - 1];
> > -
> > -    Double[] sampleOutput = new Double[output.getDimension()];
> > -    int sampleOutputIntValue = trainingExample.getOutput().intValue();
> > -    if (sampleOutputIntValue < sampleOutput.length) {
> > -      sampleOutput[sampleOutputIntValue] = 1d;
> > -    } else if (sampleOutput.length == 1) {
> > -      sampleOutput[0] = trainingExample.getOutput();
> > -    } else {
> > -      throw new RuntimeException("problem with multiclass output mapping");
> > -    }
> > -    RealVector learnedOutputRealVector = new ArrayRealVector(sampleOutput); // turn example output to a vector
> > -
> > -    // TODO : improve error calculation -> this could be er_a = out_a * (1 - out_a) * (tgt_a - out_a)
> > -    return output.subtract(learnedOutputRealVector);
> > -  }
> >  }
> >
> > Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/BasicPerceptron.java
> > URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/BasicPerceptron.java?rev=1705721&r1=1705720&r2=1705721&view=diff
> > ==============================================================================
> > --- labs/yay/trunk/core/src/main/java/org/apache/yay/core/BasicPerceptron.java (original)
> > +++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/BasicPerceptron.java Mon Sep 28 16:49:57 2015
> > @@ -94,4 +94,13 @@ public class BasicPerceptron implements
> >      return perceptronNeuron.elaborate(ConversionUtils.toValuesCollection(input.getFeatures()).toArray(
> >              new Double[input.getFeatures().size()]));
> >    }
> > +
> > +  @Override
> > +  public Double[] getOutputVector(Input<Double> input) throws PredictionException {
> > +    Double elaborate = perceptronNeuron.elaborate(ConversionUtils.toValuesCollection(input.getFeatures()).toArray(
> > +            new Double[input.getFeatures().size()]));
> > +    Double[] ar = new Double[1];
> > +    ar[0] = elaborate;
> > +    return ar;
> > +  }
> >  }
> >
> > Added: labs/yay/trunk/core/src/main/java/org/apache/yay/core/DefaultDerivativeUpdateFunction.java
> > URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/DefaultDerivativeUpdateFunction.java?rev=1705721&view=auto
> > ==============================================================================
> > --- labs/yay/trunk/core/src/main/java/org/apache/yay/core/DefaultDerivativeUpdateFunction.java (added)
> > +++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/DefaultDerivativeUpdateFunction.java Mon Sep 28 16:49:57 2015
> > @@ -0,0 +1,128 @@
> > +/*
> > + * Licensed to the Apache Software Foundation (ASF) under one
> > + * or more contributor license agreements.  See the NOTICE file
> > + * distributed with this work for additional information
> > + * regarding copyright ownership.  The ASF licenses this file
> > + * to you under the Apache License, Version 2.0 (the
> > + * "License"); you may not use this file except in compliance
> > + * with the License.  You may obtain a copy of the License at
> > + *
> > + *  http://www.apache.org/licenses/LICENSE-2.0
> > + *
> > + * Unless required by applicable law or agreed to in writing,
> > + * software distributed under the License is distributed on an
> > + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
> > + * KIND, either express or implied.  See the License for the
> > + * specific language governing permissions and limitations
> > + * under the License.
> > + */
> > +package org.apache.yay.core;
> > +
> > +import org.apache.commons.math3.linear.ArrayRealVector;
> > +import org.apache.commons.math3.linear.RealMatrix;
> > +import org.apache.commons.math3.linear.RealVector;
> > +import org.apache.yay.DerivativeUpdateFunction;
> > +import org.apache.yay.PredictionStrategy;
> > +import org.apache.yay.TrainingExample;
> > +import org.apache.yay.TrainingSet;
> > +import org.apache.yay.core.utils.ConversionUtils;
> > +
> > +/**
> > + * Default derivatives update function
> > + */
> > +public class DefaultDerivativeUpdateFunction implements DerivativeUpdateFunction<Double, Double> {
> > +
> > +  private final PredictionStrategy<Double, Double> predictionStrategy;
> > +
> > +  public DefaultDerivativeUpdateFunction(PredictionStrategy<Double, Double> predictionStrategy) {
> > +    this.predictionStrategy = predictionStrategy;
> > +  }
> > +
> > +  @Override
> > +  public RealMatrix[] updateParameters(RealMatrix[] weightsMatrixSet, TrainingSet<Double, Double> trainingExamples) {
> > +    // set up the accumulator matrix(es)
> > +    RealMatrix[] triangle = new RealMatrix[weightsMatrixSet.length];
> > +    RealVector[] deltaVectors = new RealVector[weightsMatrixSet.length];
> > +
> > +    int noOfMatrixes = weightsMatrixSet.length - 1;
> > +    double count = 0;
> > +    for (TrainingExample<Double, Double> trainingExample : trainingExamples) {
> > +      try {
> > +        // get activations from feed forward propagation
> > +        RealVector[] activations = predictionStrategy.debugOutput(ConversionUtils.toValuesCollection(trainingExample.getFeatures()), weightsMatrixSet);
> > +
> > +        // calculate output error (corresponding to the last delta^l)
> > +        RealVector nextLayerDelta = calculateOutputError(trainingExample, activations);
> > +
> > +        deltaVectors[noOfMatrixes] = nextLayerDelta;
> > +
> > +        // back prop the error and update the deltas accordingly
> > +        for (int l = noOfMatrixes; l > 0; l--) {
> > +          RealVector currentActivationsVector = activations[l - 1];
> > +          nextLayerDelta = calculateDeltaVector(weightsMatrixSet[l], currentActivationsVector, nextLayerDelta);
> > +
> > +          // collect delta vectors for this example
> > +          deltaVectors[l - 1] = nextLayerDelta;
> > +        }
> > +
> > +        RealVector[] newActivations = new RealVector[activations.length];
> > +        newActivations[0] = ConversionUtils.toRealVector(ConversionUtils.toValuesCollection(trainingExample.getFeatures()));
> > +        System.arraycopy(activations, 0, newActivations, 1, activations.length - 1);
> > +
> > +        // update triangle (big delta matrix)
> > +        updateTriangle(triangle, newActivations, deltaVectors, weightsMatrixSet);
> > +
> > +      } catch (Exception e) {
> > +        throw new RuntimeException("error during derivatives calculation", e);
> > +      }
> > +      count++;
> > +    }
> > +
> > +    return createDerivatives(triangle, count);
> > +  }
> > +
> > +  private RealMatrix[] createDerivatives(RealMatrix[] triangle, double count) {
> > +    RealMatrix[] derivatives = new RealMatrix[triangle.length];
> > +    for (int i = 0; i < triangle.length; i++) {
> > +      // TODO : introduce regularization diversification on bias term (currently not regularized)
> > +      derivatives[i] = triangle[i].scalarMultiply(1d / count);
> > +    }
> > +    return derivatives;
> > +  }
> > +
> > +  private void updateTriangle(RealMatrix[] triangle, RealVector[] activations, RealVector[] deltaVectors, RealMatrix[] weightsMatrixSet) {
> > +    for (int l = weightsMatrixSet.length - 1; l >= 0; l--) {
> > +      RealMatrix realMatrix = deltaVectors[l].outerProduct(activations[l]);
> > +      if (triangle[l] == null) {
> > +        triangle[l] = realMatrix;
> > +      } else {
> > +        triangle[l] = triangle[l].add(realMatrix);
> > +      }
> > +    }
> > +  }
> > +
> > +  private RealVector calculateDeltaVector(RealMatrix thetaL, RealVector activationsVector, RealVector nextLayerDelta) {
> > +    // TODO : remove the bias term from the error calculations
> > +    ArrayRealVector identity = new ArrayRealVector(activationsVector.getDimension(), 1d);
> > +    RealVector gz = activationsVector.ebeMultiply(identity.subtract(activationsVector)); // = a^l .* (1-a^l)
> > +    return thetaL.preMultiply(nextLayerDelta).ebeMultiply(gz);
> > +  }
> > +
> > +  private RealVector calculateOutputError(TrainingExample<Double, Double> trainingExample, RealVector[] activations) {
> > +    RealVector output = activations[activations.length - 1];
> > +
> > +    Double[] sampleOutput = new Double[output.getDimension()];
> > +    int sampleOutputIntValue = trainingExample.getOutput().intValue();
> > +    if (sampleOutputIntValue < sampleOutput.length) {
> > +      sampleOutput[sampleOutputIntValue] = 1d;
> > +    } else if (sampleOutput.length == 1) {
> > +      sampleOutput[0] = trainingExample.getOutput();
> > +    } else {
> > +      throw new RuntimeException("problem with multiclass output mapping");
> > +    }
> > +    RealVector learnedOutputRealVector = new ArrayRealVector(sampleOutput); // turn example output to a vector
> > +
> > +    // TODO : improve error calculation -> this could be er_a = out_a * (1 - out_a) * (tgt_a - out_a)
> > +    return output.subtract(learnedOutputRealVector);
> > +  }
> > +}
> >
> > Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java
> > URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java?rev=1705721&r1=1705720&r2=1705721&view=diff
> > ==============================================================================
> > --- labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java (original)
> > +++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java Mon Sep 28 16:49:57 2015
> > @@ -16,7 +16,6 @@
> >   * specific language governing permissions and limitations
> >   * under the License.
> >   */
> > -
> >  package org.apache.yay.core;
> >
> >  import java.util.ArrayList;
> >
> > Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/NeuralNetworkFactory.java
> > URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/NeuralNetworkFactory.java?rev=1705721&r1=1705720&r2=1705721&view=diff
> > ==============================================================================
> > --- labs/yay/trunk/core/src/main/java/org/apache/yay/core/NeuralNetworkFactory.java (original)
> > +++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/NeuralNetworkFactory.java Mon Sep 28 16:49:57 2015
> > @@ -18,9 +18,11 @@
> >   */
> >  package org.apache.yay.core;
> >
> > +import java.util.ArrayList;
> >  import java.util.Arrays;
> >  import java.util.Collection;
> >  import org.apache.commons.math3.linear.RealMatrix;
> > +import org.apache.commons.math3.linear.RealVector;
> >  import org.apache.yay.CreationException;
> >  import org.apache.yay.Input;
> >  import org.apache.yay.LearningException;
> > @@ -53,6 +55,12 @@ public class NeuralNetworkFactory {
> >                                        final SelectionFunction<Collection<Double>, Double> selectionFunction) throws CreationException {
> >      return new NeuralNetwork() {
> >
> > +      @Override
> > +      public Double[] getOutputVector(Input<Double> input) throws PredictionException {
> > +        Collection<Double> inputVector = ConversionUtils.toValuesCollection(input.getFeatures());
> > +        return predictionStrategy.predictOutput(inputVector, updatedRealMatrixSet);
> > +      }
> > +
> >        private RealMatrix[] updatedRealMatrixSet = realMatrixSet;
> >
> >        @Override
> > @@ -77,8 +85,7 @@ public class NeuralNetworkFactory {
> >        @Override
> >        public Double predict(Input<Double> input) throws PredictionException {
> >          try {
> > -          Collection<Double> inputVector = ConversionUtils.toValuesCollection(input.getFeatures());
> > -          Double[] doubles = predictionStrategy.predictOutput(inputVector, updatedRealMatrixSet);
> > +          Double[] doubles = getOutputVector(input);
> >            return selectionFunction.selectOutput(Arrays.asList(doubles));
> >          } catch (Exception e) {
> >            throw new PredictionException(e);
> >
> > Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/utils/ExamplesFactory.java
> > URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/utils/ExamplesFactory.java?rev=1705721&r1=1705720&r2=1705721&view=diff
> > ==============================================================================
> > --- labs/yay/trunk/core/src/main/java/org/apache/yay/core/utils/ExamplesFactory.java (original)
> > +++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/utils/ExamplesFactory.java Mon Sep 28 16:49:57 2015
> > @@ -19,6 +19,8 @@
> >  package org.apache.yay.core.utils;
> >
> >  import java.util.ArrayList;
> > +import java.util.Collection;
> > +
> >  import org.apache.yay.Feature;
> >  import org.apache.yay.Input;
> >  import org.apache.yay.TrainingExample;
> > @@ -41,6 +43,21 @@ public class ExamplesFactory {
> >          return output;
> >        }
> >      };
> > +  }
> > +
> > +  public static TrainingExample<Double, Collection<Double[]>> createSGMExample(final Collection<Double[]> output,
> > +                                                                               final Double... featuresValues) {
> > +    return new TrainingExample<Double, Collection<Double[]>>() {
> > +      @Override
> > +      public ArrayList<Feature<Double>> getFeatures() {
> > +        return doublesToFeatureVector(featuresValues);
> > +      }
> > +
> > +      @Override
> > +      public Collection<Double[]> getOutput() {
> > +        return output;
> > +      }
> > +    };
> >    }
> >
> >    public static Input<Double> createDoubleInput(final Double... featuresValues) {
> >
> > Modified: labs/yay/trunk/core/src/test/java/org/apache/yay/core/NeuralNetworkFactoryTest.java
> > URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/core/NeuralNetworkFactoryTest.java?rev=1705721&r1=1705720&r2=1705721&view=diff
> > ==============================================================================
> > --- labs/yay/trunk/core/src/test/java/org/apache/yay/core/NeuralNetworkFactoryTest.java (original)
> > +++ labs/yay/trunk/core/src/test/java/org/apache/yay/core/NeuralNetworkFactoryTest.java Mon Sep 28 16:49:57 2015
> > @@ -84,8 +84,11 @@ public class NeuralNetworkFactoryTest {
> >    public void sampleCreationTest() throws Exception {
> >      RealMatrix firstLayer = new Array2DRowRealMatrix(new double[][]{{1d, 1d, 2d, 3d}, {1d, 1d, 2d, 3d}, {1d, 1d, 2d, 3d}});
> >      RealMatrix secondLayer = new Array2DRowRealMatrix(new double[][]{{1d, 2d, 3d}});
> > +
> >      RealMatrix[] RealMatrixes = new RealMatrix[]{firstLayer, secondLayer};
> > +
> >      NeuralNetwork neuralNetwork = createFFNN(RealMatrixes);
> > +
> >      Double prdictedValue = neuralNetwork.predict(createSample(5d, 6d, 7d));
> >      assertEquals(1l, Math.round(prdictedValue));
> >      assertEquals(Double.valueOf(0.9975273768433653d), prdictedValue);
> >
> >
> >
> > ---------------------------------------------------------------------
> > To unsubscribe, e-mail: commits-unsubscr...@labs.apache.org
> > For additional commands, e-mail: commits-h...@labs.apache.org
> >
> >
>
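PS: for anyone following the yay commits who is curious about the change quoted above: the point of the new DerivativeUpdateFunction interface is that the gradient computation becomes pluggable instead of being hard-wired into BackPropagationLearningStrategy. As a rough, untested sketch only (the class name, the lambda field and the L2-penalty idea are illustrative assumptions of mine, not part of the commit), a custom implementation could decorate the default one like this:

package org.apache.yay.core;

import org.apache.commons.math3.linear.RealMatrix;
import org.apache.yay.DerivativeUpdateFunction;
import org.apache.yay.TrainingSet;

/**
 * Illustrative sketch (not in the commit above): wraps another
 * DerivativeUpdateFunction and adds an L2 penalty term to the derivatives.
 */
public class L2RegularizedDerivativeUpdateFunction implements DerivativeUpdateFunction<Double, Double> {

  private final DerivativeUpdateFunction<Double, Double> delegate;
  private final double lambda;

  public L2RegularizedDerivativeUpdateFunction(DerivativeUpdateFunction<Double, Double> delegate, double lambda) {
    this.delegate = delegate;
    this.lambda = lambda;
  }

  @Override
  public RealMatrix[] updateParameters(RealMatrix[] weights, TrainingSet<Double, Double> trainingSet) {
    // let the wrapped function (e.g. DefaultDerivativeUpdateFunction) compute the raw derivatives
    RealMatrix[] derivatives = delegate.updateParameters(weights, trainingSet);
    for (int l = 0; l < derivatives.length; l++) {
      // add lambda * theta^l to each layer's gradient (a real implementation would skip the bias column)
      derivatives[l] = derivatives[l].add(weights[l].scalarMultiply(lambda));
    }
    return derivatives;
  }
}

Wiring something like this into BackPropagationLearningStrategy would also need a constructor that accepts a DerivativeUpdateFunction, which the quoted commit does not add yet.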


-- 
Sent from My iPad, sorry for any misspellings.
