I need help. I want out of this mailing list; please unsubscribe me.
On Sep 28, 2015 11:50 AM, <tomm...@apache.org> wrote:

> Author: tommaso
> Date: Mon Sep 28 16:49:57 2015
> New Revision: 1705721
>
> URL: http://svn.apache.org/viewvc?rev=1705721&view=rev
> Log:
> switch from batch to stochastic GD in backprop, abstracted derivative
> update function
>
> Added:
>
> labs/yay/trunk/api/src/main/java/org/apache/yay/DerivativeUpdateFunction.java
>
> labs/yay/trunk/core/src/main/java/org/apache/yay/core/DefaultDerivativeUpdateFunction.java
> Removed:
>     labs/yay/trunk/api/src/main/java/org/apache/yay/HypothesisFactory.java
> Modified:
>     labs/yay/trunk/api/src/main/java/org/apache/yay/NeuralNetwork.java
>
> labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java
>
> labs/yay/trunk/core/src/main/java/org/apache/yay/core/BasicPerceptron.java
>
> labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java
>
> labs/yay/trunk/core/src/main/java/org/apache/yay/core/NeuralNetworkFactory.java
>
> labs/yay/trunk/core/src/main/java/org/apache/yay/core/utils/ExamplesFactory.java
>
> labs/yay/trunk/core/src/test/java/org/apache/yay/core/NeuralNetworkFactoryTest.java
>
> Added:
> labs/yay/trunk/api/src/main/java/org/apache/yay/DerivativeUpdateFunction.java
> URL:
> http://svn.apache.org/viewvc/labs/yay/trunk/api/src/main/java/org/apache/yay/DerivativeUpdateFunction.java?rev=1705721&view=auto
>
> ==============================================================================
> ---
> labs/yay/trunk/api/src/main/java/org/apache/yay/DerivativeUpdateFunction.java
> (added)
> +++
> labs/yay/trunk/api/src/main/java/org/apache/yay/DerivativeUpdateFunction.java
> Mon Sep 28 16:49:57 2015
> @@ -0,0 +1,30 @@
> +/*
> + * Licensed to the Apache Software Foundation (ASF) under one
> + * or more contributor license agreements.  See the NOTICE file
> + * distributed with this work for additional information
> + * regarding copyright ownership.  The ASF licenses this file
> + * to you under the Apache License, Version 2.0 (the
> + * "License"); you may not use this file except in compliance
> + * with the License.  You may obtain a copy of the License at
> + *
> + *  http://www.apache.org/licenses/LICENSE-2.0
> + *
> + * Unless required by applicable law or agreed to in writing,
> + * software distributed under the License is distributed on an
> + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
> + * KIND, either express or implied.  See the License for the
> + * specific language governing permissions and limitations
> + * under the License.
> + */
> +package org.apache.yay;
> +
> +import org.apache.commons.math3.linear.RealMatrix;
> +import org.apache.yay.TrainingSet;
> +
> +/**
> + * Derivatives update function
> + */
> +public interface DerivativeUpdateFunction<F,O> {
> +
> +  RealMatrix[] updateParameters(RealMatrix[] weights, TrainingSet<F,O>
> trainingSet);
> +}
>
> Modified:
> labs/yay/trunk/api/src/main/java/org/apache/yay/NeuralNetwork.java
> URL:
> http://svn.apache.org/viewvc/labs/yay/trunk/api/src/main/java/org/apache/yay/NeuralNetwork.java?rev=1705721&r1=1705720&r2=1705721&view=diff
>
> ==============================================================================
> --- labs/yay/trunk/api/src/main/java/org/apache/yay/NeuralNetwork.java
> (original)
> +++ labs/yay/trunk/api/src/main/java/org/apache/yay/NeuralNetwork.java Mon
> Sep 28 16:49:57 2015
> @@ -21,8 +21,19 @@ package org.apache.yay;
>  import org.apache.commons.math3.linear.RealMatrix;
>
>  /**
> - * A neural network is a layered connected graph of elaboration units
> + * A Neural Network is a layered connected graph of elaboration units.
> + *
> + * It takes a double vector as input and produces a double vector as
> output.
>   */
>  public interface NeuralNetwork extends Hypothesis<RealMatrix, Double,
> Double> {
>
> +  /**
> +   * Predict the output for a given input
> +   *
> +   * @param input the input to evaluate
> +   * @return the predicted output
> +   * @throws PredictionException if any error occurs during the
> prediction phase
> +   */
> +  Double[] getOutputVector(Input<Double> input) throws
> PredictionException;
> +
>  }
>
> Modified:
> labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java
> URL:
> http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java?rev=1705721&r1=1705720&r2=1705721&view=diff
>
> ==============================================================================
> ---
> labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java
> (original)
> +++
> labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java
> Mon Sep 28 16:49:57 2015
> @@ -19,6 +19,7 @@
>  package org.apache.yay.core;
>
>  import java.util.Arrays;
> +import java.util.DoubleSummaryStatistics;
>  import java.util.Iterator;
>
>  import org.apache.commons.math3.linear.Array2DRowRealMatrix;
> @@ -26,6 +27,7 @@ import org.apache.commons.math3.linear.A
>  import org.apache.commons.math3.linear.RealMatrix;
>  import org.apache.commons.math3.linear.RealVector;
>  import org.apache.yay.CostFunction;
> +import org.apache.yay.DerivativeUpdateFunction;
>  import org.apache.yay.LearningStrategy;
>  import org.apache.yay.NeuralNetwork;
>  import org.apache.yay.PredictionStrategy;
> @@ -46,6 +48,7 @@ public class BackPropagationLearningStra
>
>    private final PredictionStrategy<Double, Double> predictionStrategy;
>    private final CostFunction<RealMatrix, Double, Double> costFunction;
> +  private final DerivativeUpdateFunction<Double, Double>
> derivativeUpdateFunction;
>    private final double alpha;
>    private final double threshold;
>    private final int batch;
> @@ -63,6 +66,7 @@ public class BackPropagationLearningStra
>      this.alpha = alpha;
>      this.threshold = threshold;
>      this.batch = batch;
> +    this.derivativeUpdateFunction = new
> DefaultDerivativeUpdateFunction(predictionStrategy);
>    }
>
>    public BackPropagationLearningStrategy() {
> @@ -72,6 +76,7 @@ public class BackPropagationLearningStra
>      this.alpha = DEFAULT_ALPHA;
>      this.threshold = DEFAULT_THRESHOLD;
>      this.batch = 1;
> +    this.derivativeUpdateFunction = new
> DefaultDerivativeUpdateFunction(predictionStrategy);
>    }
>
>    @Override
> @@ -114,7 +119,7 @@ public class BackPropagationLearningStra
>          cost = newCost;
>
>          // calculate the derivatives to update the parameters
> -        RealMatrix[] derivatives = calculateDerivatives(weightsMatrixSet,
> samples);
> +        RealMatrix[] derivatives =
> derivativeUpdateFunction.updateParameters(weightsMatrixSet, samples);
>
>          // calculate the updated parameters
>          updatedWeights = updateWeights(weightsMatrixSet, derivatives,
> alpha);
> @@ -131,48 +136,6 @@ public class BackPropagationLearningStra
>      return updatedWeights;
>    }
>
> -  private RealMatrix[] calculateDerivatives(RealMatrix[]
> weightsMatrixSet, TrainingSet<Double, Double> trainingExamples) throws
> WeightLearningException {
> -    // set up the accumulator matrix(es)
> -    RealMatrix[] triangle = new RealMatrix[weightsMatrixSet.length];
> -    RealVector[] deltaVectors = new RealVector[weightsMatrixSet.length];
> -
> -    int noOfMatrixes = weightsMatrixSet.length - 1;
> -    double count = 0;
> -    for (TrainingExample<Double, Double> trainingExample :
> trainingExamples) {
> -      try {
> -        // get activations from feed forward propagation
> -        RealVector[] activations =
> predictionStrategy.debugOutput(ConversionUtils.toValuesCollection(trainingExample.getFeatures()),
> weightsMatrixSet);
> -
> -        // calculate output error (corresponding to the last delta^l)
> -        RealVector nextLayerDelta = calculateOutputError(trainingExample,
> activations);
> -
> -        deltaVectors[noOfMatrixes] = nextLayerDelta;
> -
> -        // back prop the error and update the deltas accordingly
> -        for (int l = noOfMatrixes; l > 0; l--) {
> -          RealVector currentActivationsVector = activations[l - 1];
> -          nextLayerDelta = calculateDeltaVector(weightsMatrixSet[l],
> currentActivationsVector, nextLayerDelta);
> -
> -          // collect delta vectors for this example
> -          deltaVectors[l - 1] = nextLayerDelta;
> -        }
> -
> -        RealVector[] newActivations = new RealVector[activations.length];
> -        newActivations[0] =
> ConversionUtils.toRealVector(ConversionUtils.toValuesCollection(trainingExample.getFeatures()));
> -        System.arraycopy(activations, 0, newActivations, 1,
> activations.length - 1);
> -
> -        // update triangle (big delta matrix)
> -        updateTriangle(triangle, newActivations, deltaVectors,
> weightsMatrixSet);
> -
> -      } catch (Exception e) {
> -        throw new WeightLearningException("error during derivatives
> calculation", e);
> -      }
> -      count++;
> -    }
> -
> -    return createDerivatives(triangle, count);
> -  }
> -
>    private RealMatrix[] updateWeights(RealMatrix[] weightsMatrixSet,
> RealMatrix[] derivatives, double alpha) {
>      RealMatrix[] updatedParameters = new
> RealMatrix[weightsMatrixSet.length];
>      for (int l = 0; l < weightsMatrixSet.length; l++) {
> @@ -187,48 +150,4 @@ public class BackPropagationLearningStra
>      return updatedParameters;
>    }
>
> -  private RealMatrix[] createDerivatives(RealMatrix[] triangle, double
> count) {
> -    RealMatrix[] derivatives = new RealMatrix[triangle.length];
> -    for (int i = 0; i < triangle.length; i++) {
> -      // TODO : introduce regularization diversification on bias term
> (currently not regularized)
> -      derivatives[i] = triangle[i].scalarMultiply(1d / count);
> -    }
> -    return derivatives;
> -  }
> -
> -  private void updateTriangle(RealMatrix[] triangle, RealVector[]
> activations, RealVector[] deltaVectors, RealMatrix[] weightsMatrixSet) {
> -    for (int l = weightsMatrixSet.length - 1; l >= 0; l--) {
> -      RealMatrix realMatrix =
> deltaVectors[l].outerProduct(activations[l]);
> -      if (triangle[l] == null) {
> -        triangle[l] = realMatrix;
> -      } else {
> -        triangle[l] = triangle[l].add(realMatrix);
> -      }
> -    }
> -  }
> -
> -  private RealVector calculateDeltaVector(RealMatrix thetaL, RealVector
> activationsVector, RealVector nextLayerDelta) {
> -    // TODO : remove the bias term from the error calculations
> -    ArrayRealVector identity = new
> ArrayRealVector(activationsVector.getDimension(), 1d);
> -    RealVector gz =
> activationsVector.ebeMultiply(identity.subtract(activationsVector)); // =
> a^l .* (1-a^l)
> -    return thetaL.preMultiply(nextLayerDelta).ebeMultiply(gz);
> -  }
> -
> -  private RealVector calculateOutputError(TrainingExample<Double, Double>
> trainingExample, RealVector[] activations) {
> -    RealVector output = activations[activations.length - 1];
> -
> -    Double[] sampleOutput = new Double[output.getDimension()];
> -    int sampleOutputIntValue = trainingExample.getOutput().intValue();
> -    if (sampleOutputIntValue < sampleOutput.length) {
> -      sampleOutput[sampleOutputIntValue] = 1d;
> -    } else if (sampleOutput.length == 1) {
> -      sampleOutput[0] = trainingExample.getOutput();
> -    } else {
> -      throw new RuntimeException("problem with multiclass output
> mapping");
> -    }
> -    RealVector learnedOutputRealVector = new
> ArrayRealVector(sampleOutput); // turn example output to a vector
> -
> -    // TODO : improve error calculation -> this could be er_a = out_a *
> (1 - out_a) * (tgt_a - out_a)
> -    return output.subtract(learnedOutputRealVector);
> -  }
>  }
>
> Modified:
> labs/yay/trunk/core/src/main/java/org/apache/yay/core/BasicPerceptron.java
> URL:
> http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/BasicPerceptron.java?rev=1705721&r1=1705720&r2=1705721&view=diff
>
> ==============================================================================
> ---
> labs/yay/trunk/core/src/main/java/org/apache/yay/core/BasicPerceptron.java
> (original)
> +++
> labs/yay/trunk/core/src/main/java/org/apache/yay/core/BasicPerceptron.java
> Mon Sep 28 16:49:57 2015
> @@ -94,4 +94,13 @@ public class BasicPerceptron implements
>      return
> perceptronNeuron.elaborate(ConversionUtils.toValuesCollection(input.getFeatures()).toArray(
>              new Double[input.getFeatures().size()]));
>    }
> +
> +  @Override
> +  public Double[] getOutputVector(Input<Double> input) throws
> PredictionException {
> +    Double elaborate =
> perceptronNeuron.elaborate(ConversionUtils.toValuesCollection(input.getFeatures()).toArray(
> +            new Double[input.getFeatures().size()]));
> +    Double[] ar = new Double[1];
> +    ar[0] = elaborate;
> +    return ar;
> +  }
>  }
>
> Added:
> labs/yay/trunk/core/src/main/java/org/apache/yay/core/DefaultDerivativeUpdateFunction.java
> URL:
> http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/DefaultDerivativeUpdateFunction.java?rev=1705721&view=auto
>
> ==============================================================================
> ---
> labs/yay/trunk/core/src/main/java/org/apache/yay/core/DefaultDerivativeUpdateFunction.java
> (added)
> +++
> labs/yay/trunk/core/src/main/java/org/apache/yay/core/DefaultDerivativeUpdateFunction.java
> Mon Sep 28 16:49:57 2015
> @@ -0,0 +1,128 @@
> +/*
> + * Licensed to the Apache Software Foundation (ASF) under one
> + * or more contributor license agreements.  See the NOTICE file
> + * distributed with this work for additional information
> + * regarding copyright ownership.  The ASF licenses this file
> + * to you under the Apache License, Version 2.0 (the
> + * "License"); you may not use this file except in compliance
> + * with the License.  You may obtain a copy of the License at
> + *
> + *  http://www.apache.org/licenses/LICENSE-2.0
> + *
> + * Unless required by applicable law or agreed to in writing,
> + * software distributed under the License is distributed on an
> + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
> + * KIND, either express or implied.  See the License for the
> + * specific language governing permissions and limitations
> + * under the License.
> + */
> +package org.apache.yay.core;
> +
> +import org.apache.commons.math3.linear.ArrayRealVector;
> +import org.apache.commons.math3.linear.RealMatrix;
> +import org.apache.commons.math3.linear.RealVector;
> +import org.apache.yay.DerivativeUpdateFunction;
> +import org.apache.yay.PredictionStrategy;
> +import org.apache.yay.TrainingExample;
> +import org.apache.yay.TrainingSet;
> +import org.apache.yay.core.utils.ConversionUtils;
> +
> +/**
> + * Default derivatives update function
> + */
> +public class DefaultDerivativeUpdateFunction implements
> DerivativeUpdateFunction<Double, Double> {
> +
> +  private final PredictionStrategy<Double, Double> predictionStrategy;
> +
> +  public DefaultDerivativeUpdateFunction(PredictionStrategy<Double,
> Double> predictionStrategy) {
> +    this.predictionStrategy = predictionStrategy;
> +  }
> +
> +  @Override
> +  public RealMatrix[] updateParameters(RealMatrix[] weightsMatrixSet,
> TrainingSet<Double, Double> trainingExamples) {
> +    // set up the accumulator matrix(es)
> +    RealMatrix[] triangle = new RealMatrix[weightsMatrixSet.length];
> +    RealVector[] deltaVectors = new RealVector[weightsMatrixSet.length];
> +
> +    int noOfMatrixes = weightsMatrixSet.length - 1;
> +    double count = 0;
> +    for (TrainingExample<Double, Double> trainingExample :
> trainingExamples) {
> +      try {
> +        // get activations from feed forward propagation
> +        RealVector[] activations =
> predictionStrategy.debugOutput(ConversionUtils.toValuesCollection(trainingExample.getFeatures()),
> weightsMatrixSet);
> +
> +        // calculate output error (corresponding to the last delta^l)
> +        RealVector nextLayerDelta = calculateOutputError(trainingExample,
> activations);
> +
> +        deltaVectors[noOfMatrixes] = nextLayerDelta;
> +
> +        // back prop the error and update the deltas accordingly
> +        for (int l = noOfMatrixes; l > 0; l--) {
> +          RealVector currentActivationsVector = activations[l - 1];
> +          nextLayerDelta = calculateDeltaVector(weightsMatrixSet[l],
> currentActivationsVector, nextLayerDelta);
> +
> +          // collect delta vectors for this example
> +          deltaVectors[l - 1] = nextLayerDelta;
> +        }
> +
> +        RealVector[] newActivations = new RealVector[activations.length];
> +        newActivations[0] =
> ConversionUtils.toRealVector(ConversionUtils.toValuesCollection(trainingExample.getFeatures()));
> +        System.arraycopy(activations, 0, newActivations, 1,
> activations.length - 1);
> +
> +        // update triangle (big delta matrix)
> +        updateTriangle(triangle, newActivations, deltaVectors,
> weightsMatrixSet);
> +
> +      } catch (Exception e) {
> +        throw new RuntimeException("error during derivatives
> calculation", e);
> +      }
> +      count++;
> +    }
> +
> +    return createDerivatives(triangle, count);
> +  }
> +
> +  private RealMatrix[] createDerivatives(RealMatrix[] triangle, double
> count) {
> +    RealMatrix[] derivatives = new RealMatrix[triangle.length];
> +    for (int i = 0; i < triangle.length; i++) {
> +      // TODO : introduce regularization diversification on bias term
> (currently not regularized)
> +      derivatives[i] = triangle[i].scalarMultiply(1d / count);
> +    }
> +    return derivatives;
> +  }
> +
> +  private void updateTriangle(RealMatrix[] triangle, RealVector[]
> activations, RealVector[] deltaVectors, RealMatrix[] weightsMatrixSet) {
> +    for (int l = weightsMatrixSet.length - 1; l >= 0; l--) {
> +      RealMatrix realMatrix =
> deltaVectors[l].outerProduct(activations[l]);
> +      if (triangle[l] == null) {
> +        triangle[l] = realMatrix;
> +      } else {
> +        triangle[l] = triangle[l].add(realMatrix);
> +      }
> +    }
> +  }
> +
> +  private RealVector calculateDeltaVector(RealMatrix thetaL, RealVector
> activationsVector, RealVector nextLayerDelta) {
> +    // TODO : remove the bias term from the error calculations
> +    ArrayRealVector identity = new
> ArrayRealVector(activationsVector.getDimension(), 1d);
> +    RealVector gz =
> activationsVector.ebeMultiply(identity.subtract(activationsVector)); // =
> a^l .* (1-a^l)
> +    return thetaL.preMultiply(nextLayerDelta).ebeMultiply(gz);
> +  }
> +
> +  private RealVector calculateOutputError(TrainingExample<Double, Double>
> trainingExample, RealVector[] activations) {
> +    RealVector output = activations[activations.length - 1];
> +
> +    Double[] sampleOutput = new Double[output.getDimension()];
> +    int sampleOutputIntValue = trainingExample.getOutput().intValue();
> +    if (sampleOutputIntValue < sampleOutput.length) {
> +      sampleOutput[sampleOutputIntValue] = 1d;
> +    } else if (sampleOutput.length == 1) {
> +      sampleOutput[0] = trainingExample.getOutput();
> +    } else {
> +      throw new RuntimeException("problem with multiclass output
> mapping");
> +    }
> +    RealVector learnedOutputRealVector = new
> ArrayRealVector(sampleOutput); // turn example output to a vector
> +
> +    // TODO : improve error calculation -> this could be er_a = out_a *
> (1 - out_a) * (tgt_a - out_a)
> +    return output.subtract(learnedOutputRealVector);
> +  }
> +}
>
> Modified:
> labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java
> URL:
> http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java?rev=1705721&r1=1705720&r2=1705721&view=diff
>
> ==============================================================================
> ---
> labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java
> (original)
> +++
> labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java
> Mon Sep 28 16:49:57 2015
> @@ -16,7 +16,6 @@
>   * specific language governing permissions and limitations
>   * under the License.
>   */
> -
>  package org.apache.yay.core;
>
>  import java.util.ArrayList;
>
> Modified:
> labs/yay/trunk/core/src/main/java/org/apache/yay/core/NeuralNetworkFactory.java
> URL:
> http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/NeuralNetworkFactory.java?rev=1705721&r1=1705720&r2=1705721&view=diff
>
> ==============================================================================
> ---
> labs/yay/trunk/core/src/main/java/org/apache/yay/core/NeuralNetworkFactory.java
> (original)
> +++
> labs/yay/trunk/core/src/main/java/org/apache/yay/core/NeuralNetworkFactory.java
> Mon Sep 28 16:49:57 2015
> @@ -18,9 +18,11 @@
>   */
>  package org.apache.yay.core;
>
> +import java.util.ArrayList;
>  import java.util.Arrays;
>  import java.util.Collection;
>  import org.apache.commons.math3.linear.RealMatrix;
> +import org.apache.commons.math3.linear.RealVector;
>  import org.apache.yay.CreationException;
>  import org.apache.yay.Input;
>  import org.apache.yay.LearningException;
> @@ -53,6 +55,12 @@ public class NeuralNetworkFactory {
>                                       final
> SelectionFunction<Collection<Double>, Double> selectionFunction) throws
> CreationException {
>      return new NeuralNetwork() {
>
> +      @Override
> +      public Double[] getOutputVector(Input<Double> input) throws
> PredictionException {
> +        Collection<Double> inputVector =
> ConversionUtils.toValuesCollection(input.getFeatures());
> +        return predictionStrategy.predictOutput(inputVector,
> updatedRealMatrixSet);
> +      }
> +
>        private RealMatrix[] updatedRealMatrixSet = realMatrixSet;
>
>        @Override
> @@ -77,8 +85,7 @@ public class NeuralNetworkFactory {
>        @Override
>        public Double predict(Input<Double> input) throws
> PredictionException {
>          try {
> -          Collection<Double> inputVector =
> ConversionUtils.toValuesCollection(input.getFeatures());
> -          Double[] doubles =
> predictionStrategy.predictOutput(inputVector, updatedRealMatrixSet);
> +          Double[] doubles = getOutputVector(input);
>            return selectionFunction.selectOutput(Arrays.asList(doubles));
>          } catch (Exception e) {
>            throw new PredictionException(e);
>
> Modified:
> labs/yay/trunk/core/src/main/java/org/apache/yay/core/utils/ExamplesFactory.java
> URL:
> http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/utils/ExamplesFactory.java?rev=1705721&r1=1705720&r2=1705721&view=diff
>
> ==============================================================================
> ---
> labs/yay/trunk/core/src/main/java/org/apache/yay/core/utils/ExamplesFactory.java
> (original)
> +++
> labs/yay/trunk/core/src/main/java/org/apache/yay/core/utils/ExamplesFactory.java
> Mon Sep 28 16:49:57 2015
> @@ -19,6 +19,8 @@
>  package org.apache.yay.core.utils;
>
>  import java.util.ArrayList;
> +import java.util.Collection;
> +
>  import org.apache.yay.Feature;
>  import org.apache.yay.Input;
>  import org.apache.yay.TrainingExample;
> @@ -41,6 +43,21 @@ public class ExamplesFactory {
>          return output;
>        }
>      };
> +  }
> +
> +  public static TrainingExample<Double, Collection<Double[]>>
> createSGMExample(final Collection<Double[]> output,
> +
>   final Double... featuresValues) {
> +    return new TrainingExample<Double, Collection<Double[]>>() {
> +      @Override
> +      public ArrayList<Feature<Double>> getFeatures() {
> +        return doublesToFeatureVector(featuresValues);
> +      }
> +
> +      @Override
> +      public Collection<Double[]> getOutput() {
> +        return output;
> +      }
> +    };
>    }
>
>    public static Input<Double> createDoubleInput(final Double...
> featuresValues) {
>
> Modified:
> labs/yay/trunk/core/src/test/java/org/apache/yay/core/NeuralNetworkFactoryTest.java
> URL:
> http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/core/NeuralNetworkFactoryTest.java?rev=1705721&r1=1705720&r2=1705721&view=diff
>
> ==============================================================================
> ---
> labs/yay/trunk/core/src/test/java/org/apache/yay/core/NeuralNetworkFactoryTest.java
> (original)
> +++
> labs/yay/trunk/core/src/test/java/org/apache/yay/core/NeuralNetworkFactoryTest.java
> Mon Sep 28 16:49:57 2015
> @@ -84,8 +84,11 @@ public class NeuralNetworkFactoryTest {
>    public void sampleCreationTest() throws Exception {
>      RealMatrix firstLayer = new Array2DRowRealMatrix(new double[][]{{1d,
> 1d, 2d, 3d}, {1d, 1d, 2d, 3d}, {1d, 1d, 2d, 3d}});
>      RealMatrix secondLayer = new Array2DRowRealMatrix(new double[][]{{1d,
> 2d, 3d}});
> +
>      RealMatrix[] RealMatrixes = new RealMatrix[]{firstLayer, secondLayer};
> +
>      NeuralNetwork neuralNetwork = createFFNN(RealMatrixes);
> +
>      Double prdictedValue = neuralNetwork.predict(createSample(5d, 6d,
> 7d));
>      assertEquals(1l, Math.round(prdictedValue));
>      assertEquals(Double.valueOf(0.9975273768433653d), prdictedValue);
>
>
>
> ---------------------------------------------------------------------
> To unsubscribe, e-mail: commits-unsubscr...@labs.apache.org
> For additional commands, e-mail: commits-h...@labs.apache.org
>
>

Reply via email to