Author: tommaso
Date: Mon Sep 28 16:49:57 2015
New Revision: 1705721

URL: http://svn.apache.org/viewvc?rev=1705721&view=rev
Log:
switched from batch to stochastic GD in backprop, abstracted the derivative update function

Added:
    labs/yay/trunk/api/src/main/java/org/apache/yay/DerivativeUpdateFunction.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/core/DefaultDerivativeUpdateFunction.java
Removed:
    labs/yay/trunk/api/src/main/java/org/apache/yay/HypothesisFactory.java
Modified:
    labs/yay/trunk/api/src/main/java/org/apache/yay/NeuralNetwork.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/core/BasicPerceptron.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/core/NeuralNetworkFactory.java
    labs/yay/trunk/core/src/main/java/org/apache/yay/core/utils/ExamplesFactory.java
    labs/yay/trunk/core/src/test/java/org/apache/yay/core/NeuralNetworkFactoryTest.java

Added: labs/yay/trunk/api/src/main/java/org/apache/yay/DerivativeUpdateFunction.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/api/src/main/java/org/apache/yay/DerivativeUpdateFunction.java?rev=1705721&view=auto
==============================================================================
--- 
labs/yay/trunk/api/src/main/java/org/apache/yay/DerivativeUpdateFunction.java 
(added)
+++ 
labs/yay/trunk/api/src/main/java/org/apache/yay/DerivativeUpdateFunction.java 
Mon Sep 28 16:49:57 2015
@@ -0,0 +1,30 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *  http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.yay;
+
+import org.apache.commons.math3.linear.RealMatrix;
+import org.apache.yay.TrainingSet;
+
+/**
+ * Derivatives update function
+ */
+public interface DerivativeUpdateFunction<F,O> {
+
+  RealMatrix[] updateParameters(RealMatrix[] weights, TrainingSet<F,O> 
trainingSet);
+}

Modified: labs/yay/trunk/api/src/main/java/org/apache/yay/NeuralNetwork.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/api/src/main/java/org/apache/yay/NeuralNetwork.java?rev=1705721&r1=1705720&r2=1705721&view=diff
==============================================================================
--- labs/yay/trunk/api/src/main/java/org/apache/yay/NeuralNetwork.java 
(original)
+++ labs/yay/trunk/api/src/main/java/org/apache/yay/NeuralNetwork.java Mon Sep 
28 16:49:57 2015
@@ -21,8 +21,19 @@ package org.apache.yay;
 import org.apache.commons.math3.linear.RealMatrix;
 
 /**
- * A neural network is a layered connected graph of elaboration units
+ * A Neural Network is a layered connected graph of elaboration units.
+ *
+ * It takes a double vector as input and produces a double vector as output.
  */
 public interface NeuralNetwork extends Hypothesis<RealMatrix, Double, Double> {
 
+  /**
+   * Predict the output for a given input
+   *
+   * @param input the input to evaluate
+   * @return the predicted output
+   * @throws PredictionException if any error occurs during the prediction 
phase
+   */
+  Double[] getOutputVector(Input<Double> input) throws PredictionException;
+
 }

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java?rev=1705721&r1=1705720&r2=1705721&view=diff
==============================================================================
--- 
labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java
 (original)
+++ 
labs/yay/trunk/core/src/main/java/org/apache/yay/core/BackPropagationLearningStrategy.java
 Mon Sep 28 16:49:57 2015
@@ -19,6 +19,7 @@
 package org.apache.yay.core;
 
 import java.util.Arrays;
+import java.util.DoubleSummaryStatistics;
 import java.util.Iterator;
 
 import org.apache.commons.math3.linear.Array2DRowRealMatrix;
@@ -26,6 +27,7 @@ import org.apache.commons.math3.linear.A
 import org.apache.commons.math3.linear.RealMatrix;
 import org.apache.commons.math3.linear.RealVector;
 import org.apache.yay.CostFunction;
+import org.apache.yay.DerivativeUpdateFunction;
 import org.apache.yay.LearningStrategy;
 import org.apache.yay.NeuralNetwork;
 import org.apache.yay.PredictionStrategy;
@@ -46,6 +48,7 @@ public class BackPropagationLearningStra
 
   private final PredictionStrategy<Double, Double> predictionStrategy;
   private final CostFunction<RealMatrix, Double, Double> costFunction;
+  private final DerivativeUpdateFunction<Double, Double> 
derivativeUpdateFunction;
   private final double alpha;
   private final double threshold;
   private final int batch;
@@ -63,6 +66,7 @@ public class BackPropagationLearningStra
     this.alpha = alpha;
     this.threshold = threshold;
     this.batch = batch;
+    this.derivativeUpdateFunction = new 
DefaultDerivativeUpdateFunction(predictionStrategy);
   }
 
   public BackPropagationLearningStrategy() {
@@ -72,6 +76,7 @@ public class BackPropagationLearningStra
     this.alpha = DEFAULT_ALPHA;
     this.threshold = DEFAULT_THRESHOLD;
     this.batch = 1;
+    this.derivativeUpdateFunction = new 
DefaultDerivativeUpdateFunction(predictionStrategy);
   }
 
   @Override
@@ -114,7 +119,7 @@ public class BackPropagationLearningStra
         cost = newCost;
 
         // calculate the derivatives to update the parameters
-        RealMatrix[] derivatives = calculateDerivatives(weightsMatrixSet, 
samples);
+        RealMatrix[] derivatives = 
derivativeUpdateFunction.updateParameters(weightsMatrixSet, samples);
 
         // calculate the updated parameters
         updatedWeights = updateWeights(weightsMatrixSet, derivatives, alpha);
@@ -131,48 +136,6 @@ public class BackPropagationLearningStra
     return updatedWeights;
   }
 
-  private RealMatrix[] calculateDerivatives(RealMatrix[] weightsMatrixSet, 
TrainingSet<Double, Double> trainingExamples) throws WeightLearningException {
-    // set up the accumulator matrix(es)
-    RealMatrix[] triangle = new RealMatrix[weightsMatrixSet.length];
-    RealVector[] deltaVectors = new RealVector[weightsMatrixSet.length];
-
-    int noOfMatrixes = weightsMatrixSet.length - 1;
-    double count = 0;
-    for (TrainingExample<Double, Double> trainingExample : trainingExamples) {
-      try {
-        // get activations from feed forward propagation
-        RealVector[] activations = 
predictionStrategy.debugOutput(ConversionUtils.toValuesCollection(trainingExample.getFeatures()),
 weightsMatrixSet);
-
-        // calculate output error (corresponding to the last delta^l)
-        RealVector nextLayerDelta = calculateOutputError(trainingExample, 
activations);
-
-        deltaVectors[noOfMatrixes] = nextLayerDelta;
-
-        // back prop the error and update the deltas accordingly
-        for (int l = noOfMatrixes; l > 0; l--) {
-          RealVector currentActivationsVector = activations[l - 1];
-          nextLayerDelta = calculateDeltaVector(weightsMatrixSet[l], 
currentActivationsVector, nextLayerDelta);
-
-          // collect delta vectors for this example
-          deltaVectors[l - 1] = nextLayerDelta;
-        }
-
-        RealVector[] newActivations = new RealVector[activations.length];
-        newActivations[0] = 
ConversionUtils.toRealVector(ConversionUtils.toValuesCollection(trainingExample.getFeatures()));
-        System.arraycopy(activations, 0, newActivations, 1, activations.length 
- 1);
-
-        // update triangle (big delta matrix)
-        updateTriangle(triangle, newActivations, deltaVectors, 
weightsMatrixSet);
-
-      } catch (Exception e) {
-        throw new WeightLearningException("error during derivatives 
calculation", e);
-      }
-      count++;
-    }
-
-    return createDerivatives(triangle, count);
-  }
-
   private RealMatrix[] updateWeights(RealMatrix[] weightsMatrixSet, 
RealMatrix[] derivatives, double alpha) {
     RealMatrix[] updatedParameters = new RealMatrix[weightsMatrixSet.length];
     for (int l = 0; l < weightsMatrixSet.length; l++) {
@@ -187,48 +150,4 @@ public class BackPropagationLearningStra
     return updatedParameters;
   }
 
-  private RealMatrix[] createDerivatives(RealMatrix[] triangle, double count) {
-    RealMatrix[] derivatives = new RealMatrix[triangle.length];
-    for (int i = 0; i < triangle.length; i++) {
-      // TODO : introduce regularization diversification on bias term 
(currently not regularized)
-      derivatives[i] = triangle[i].scalarMultiply(1d / count);
-    }
-    return derivatives;
-  }
-
-  private void updateTriangle(RealMatrix[] triangle, RealVector[] activations, 
RealVector[] deltaVectors, RealMatrix[] weightsMatrixSet) {
-    for (int l = weightsMatrixSet.length - 1; l >= 0; l--) {
-      RealMatrix realMatrix = deltaVectors[l].outerProduct(activations[l]);
-      if (triangle[l] == null) {
-        triangle[l] = realMatrix;
-      } else {
-        triangle[l] = triangle[l].add(realMatrix);
-      }
-    }
-  }
-
-  private RealVector calculateDeltaVector(RealMatrix thetaL, RealVector 
activationsVector, RealVector nextLayerDelta) {
-    // TODO : remove the bias term from the error calculations
-    ArrayRealVector identity = new 
ArrayRealVector(activationsVector.getDimension(), 1d);
-    RealVector gz = 
activationsVector.ebeMultiply(identity.subtract(activationsVector)); // = a^l 
.* (1-a^l)
-    return thetaL.preMultiply(nextLayerDelta).ebeMultiply(gz);
-  }
-
-  private RealVector calculateOutputError(TrainingExample<Double, Double> 
trainingExample, RealVector[] activations) {
-    RealVector output = activations[activations.length - 1];
-
-    Double[] sampleOutput = new Double[output.getDimension()];
-    int sampleOutputIntValue = trainingExample.getOutput().intValue();
-    if (sampleOutputIntValue < sampleOutput.length) {
-      sampleOutput[sampleOutputIntValue] = 1d;
-    } else if (sampleOutput.length == 1) {
-      sampleOutput[0] = trainingExample.getOutput();
-    } else {
-      throw new RuntimeException("problem with multiclass output mapping");
-    }
-    RealVector learnedOutputRealVector = new ArrayRealVector(sampleOutput); // 
turn example output to a vector
-
-    // TODO : improve error calculation -> this could be er_a = out_a * (1 - 
out_a) * (tgt_a - out_a)
-    return output.subtract(learnedOutputRealVector);
-  }
 }

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/BasicPerceptron.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/BasicPerceptron.java?rev=1705721&r1=1705720&r2=1705721&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/core/BasicPerceptron.java 
(original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/core/BasicPerceptron.java 
Mon Sep 28 16:49:57 2015
@@ -94,4 +94,13 @@ public class BasicPerceptron implements
     return 
perceptronNeuron.elaborate(ConversionUtils.toValuesCollection(input.getFeatures()).toArray(
             new Double[input.getFeatures().size()]));
   }
+
+  @Override
+  public Double[] getOutputVector(Input<Double> input) throws 
PredictionException {
+    Double elaborate = 
perceptronNeuron.elaborate(ConversionUtils.toValuesCollection(input.getFeatures()).toArray(
+            new Double[input.getFeatures().size()]));
+    Double[] ar = new Double[1];
+    ar[0] = elaborate;
+    return ar;
+  }
 }

Added: labs/yay/trunk/core/src/main/java/org/apache/yay/core/DefaultDerivativeUpdateFunction.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/DefaultDerivativeUpdateFunction.java?rev=1705721&view=auto
==============================================================================
--- 
labs/yay/trunk/core/src/main/java/org/apache/yay/core/DefaultDerivativeUpdateFunction.java
 (added)
+++ 
labs/yay/trunk/core/src/main/java/org/apache/yay/core/DefaultDerivativeUpdateFunction.java
 Mon Sep 28 16:49:57 2015
@@ -0,0 +1,128 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *  http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.yay.core;
+
+import org.apache.commons.math3.linear.ArrayRealVector;
+import org.apache.commons.math3.linear.RealMatrix;
+import org.apache.commons.math3.linear.RealVector;
+import org.apache.yay.DerivativeUpdateFunction;
+import org.apache.yay.PredictionStrategy;
+import org.apache.yay.TrainingExample;
+import org.apache.yay.TrainingSet;
+import org.apache.yay.core.utils.ConversionUtils;
+
+/**
+ * Default derivatives update function
+ */
+public class DefaultDerivativeUpdateFunction implements 
DerivativeUpdateFunction<Double, Double> {
+
+  private final PredictionStrategy<Double, Double> predictionStrategy;
+
+  public DefaultDerivativeUpdateFunction(PredictionStrategy<Double, Double> 
predictionStrategy) {
+    this.predictionStrategy = predictionStrategy;
+  }
+
+  @Override
+  public RealMatrix[] updateParameters(RealMatrix[] weightsMatrixSet, 
TrainingSet<Double, Double> trainingExamples) {
+    // set up the accumulator matrix(es)
+    RealMatrix[] triangle = new RealMatrix[weightsMatrixSet.length];
+    RealVector[] deltaVectors = new RealVector[weightsMatrixSet.length];
+
+    int noOfMatrixes = weightsMatrixSet.length - 1;
+    double count = 0;
+    for (TrainingExample<Double, Double> trainingExample : trainingExamples) {
+      try {
+        // get activations from feed forward propagation
+        RealVector[] activations = 
predictionStrategy.debugOutput(ConversionUtils.toValuesCollection(trainingExample.getFeatures()),
 weightsMatrixSet);
+
+        // calculate output error (corresponding to the last delta^l)
+        RealVector nextLayerDelta = calculateOutputError(trainingExample, 
activations);
+
+        deltaVectors[noOfMatrixes] = nextLayerDelta;
+
+        // back prop the error and update the deltas accordingly
+        for (int l = noOfMatrixes; l > 0; l--) {
+          RealVector currentActivationsVector = activations[l - 1];
+          nextLayerDelta = calculateDeltaVector(weightsMatrixSet[l], 
currentActivationsVector, nextLayerDelta);
+
+          // collect delta vectors for this example
+          deltaVectors[l - 1] = nextLayerDelta;
+        }
+
+        RealVector[] newActivations = new RealVector[activations.length];
+        newActivations[0] = 
ConversionUtils.toRealVector(ConversionUtils.toValuesCollection(trainingExample.getFeatures()));
+        System.arraycopy(activations, 0, newActivations, 1, activations.length 
- 1);
+
+        // update triangle (big delta matrix)
+        updateTriangle(triangle, newActivations, deltaVectors, 
weightsMatrixSet);
+
+      } catch (Exception e) {
+        throw new RuntimeException("error during derivatives calculation", e);
+      }
+      count++;
+    }
+
+    return createDerivatives(triangle, count);
+  }
+
+  private RealMatrix[] createDerivatives(RealMatrix[] triangle, double count) {
+    RealMatrix[] derivatives = new RealMatrix[triangle.length];
+    for (int i = 0; i < triangle.length; i++) {
+      // TODO : introduce regularization diversification on bias term 
(currently not regularized)
+      derivatives[i] = triangle[i].scalarMultiply(1d / count);
+    }
+    return derivatives;
+  }
+
+  private void updateTriangle(RealMatrix[] triangle, RealVector[] activations, 
RealVector[] deltaVectors, RealMatrix[] weightsMatrixSet) {
+    for (int l = weightsMatrixSet.length - 1; l >= 0; l--) {
+      RealMatrix realMatrix = deltaVectors[l].outerProduct(activations[l]);
+      if (triangle[l] == null) {
+        triangle[l] = realMatrix;
+      } else {
+        triangle[l] = triangle[l].add(realMatrix);
+      }
+    }
+  }
+
+  private RealVector calculateDeltaVector(RealMatrix thetaL, RealVector 
activationsVector, RealVector nextLayerDelta) {
+    // TODO : remove the bias term from the error calculations
+    ArrayRealVector identity = new 
ArrayRealVector(activationsVector.getDimension(), 1d);
+    RealVector gz = 
activationsVector.ebeMultiply(identity.subtract(activationsVector)); // = a^l 
.* (1-a^l)
+    return thetaL.preMultiply(nextLayerDelta).ebeMultiply(gz);
+  }
+
+  private RealVector calculateOutputError(TrainingExample<Double, Double> 
trainingExample, RealVector[] activations) {
+    RealVector output = activations[activations.length - 1];
+
+    Double[] sampleOutput = new Double[output.getDimension()];
+    int sampleOutputIntValue = trainingExample.getOutput().intValue();
+    if (sampleOutputIntValue < sampleOutput.length) {
+      sampleOutput[sampleOutputIntValue] = 1d;
+    } else if (sampleOutput.length == 1) {
+      sampleOutput[0] = trainingExample.getOutput();
+    } else {
+      throw new RuntimeException("problem with multiclass output mapping");
+    }
+    RealVector learnedOutputRealVector = new ArrayRealVector(sampleOutput); // 
turn example output to a vector
+
+    // TODO : improve error calculation -> this could be er_a = out_a * (1 - 
out_a) * (tgt_a - out_a)
+    return output.subtract(learnedOutputRealVector);
+  }
+}

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java?rev=1705721&r1=1705720&r2=1705721&view=diff
==============================================================================
--- 
labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java 
(original)
+++ 
labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java 
Mon Sep 28 16:49:57 2015
@@ -16,7 +16,6 @@
  * specific language governing permissions and limitations
  * under the License.
  */
-
 package org.apache.yay.core;
 
 import java.util.ArrayList;

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/NeuralNetworkFactory.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/NeuralNetworkFactory.java?rev=1705721&r1=1705720&r2=1705721&view=diff
==============================================================================
--- 
labs/yay/trunk/core/src/main/java/org/apache/yay/core/NeuralNetworkFactory.java 
(original)
+++ 
labs/yay/trunk/core/src/main/java/org/apache/yay/core/NeuralNetworkFactory.java 
Mon Sep 28 16:49:57 2015
@@ -18,9 +18,11 @@
  */
 package org.apache.yay.core;
 
+import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collection;
 import org.apache.commons.math3.linear.RealMatrix;
+import org.apache.commons.math3.linear.RealVector;
 import org.apache.yay.CreationException;
 import org.apache.yay.Input;
 import org.apache.yay.LearningException;
@@ -53,6 +55,12 @@ public class NeuralNetworkFactory {
                                      final 
SelectionFunction<Collection<Double>, Double> selectionFunction) throws 
CreationException {
     return new NeuralNetwork() {
 
+      @Override
+      public Double[] getOutputVector(Input<Double> input) throws 
PredictionException {
+        Collection<Double> inputVector = 
ConversionUtils.toValuesCollection(input.getFeatures());
+        return predictionStrategy.predictOutput(inputVector, 
updatedRealMatrixSet);
+      }
+
       private RealMatrix[] updatedRealMatrixSet = realMatrixSet;
 
       @Override
@@ -77,8 +85,7 @@ public class NeuralNetworkFactory {
       @Override
       public Double predict(Input<Double> input) throws PredictionException {
         try {
-          Collection<Double> inputVector = 
ConversionUtils.toValuesCollection(input.getFeatures());
-          Double[] doubles = predictionStrategy.predictOutput(inputVector, 
updatedRealMatrixSet);
+          Double[] doubles = getOutputVector(input);
           return selectionFunction.selectOutput(Arrays.asList(doubles));
         } catch (Exception e) {
           throw new PredictionException(e);

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/core/utils/ExamplesFactory.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/utils/ExamplesFactory.java?rev=1705721&r1=1705720&r2=1705721&view=diff
==============================================================================
--- 
labs/yay/trunk/core/src/main/java/org/apache/yay/core/utils/ExamplesFactory.java
 (original)
+++ 
labs/yay/trunk/core/src/main/java/org/apache/yay/core/utils/ExamplesFactory.java
 Mon Sep 28 16:49:57 2015
@@ -19,6 +19,8 @@
 package org.apache.yay.core.utils;
 
 import java.util.ArrayList;
+import java.util.Collection;
+
 import org.apache.yay.Feature;
 import org.apache.yay.Input;
 import org.apache.yay.TrainingExample;
@@ -41,6 +43,21 @@ public class ExamplesFactory {
         return output;
       }
     };
+  }
+
+  public static TrainingExample<Double, Collection<Double[]>> 
createSGMExample(final Collection<Double[]> output,
+                                                                            
final Double... featuresValues) {
+    return new TrainingExample<Double, Collection<Double[]>>() {
+      @Override
+      public ArrayList<Feature<Double>> getFeatures() {
+        return doublesToFeatureVector(featuresValues);
+      }
+
+      @Override
+      public Collection<Double[]> getOutput() {
+        return output;
+      }
+    };
   }
 
   public static Input<Double> createDoubleInput(final Double... 
featuresValues) {

Modified: labs/yay/trunk/core/src/test/java/org/apache/yay/core/NeuralNetworkFactoryTest.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/core/NeuralNetworkFactoryTest.java?rev=1705721&r1=1705720&r2=1705721&view=diff
==============================================================================
--- 
labs/yay/trunk/core/src/test/java/org/apache/yay/core/NeuralNetworkFactoryTest.java
 (original)
+++ 
labs/yay/trunk/core/src/test/java/org/apache/yay/core/NeuralNetworkFactoryTest.java
 Mon Sep 28 16:49:57 2015
@@ -84,8 +84,11 @@ public class NeuralNetworkFactoryTest {
   public void sampleCreationTest() throws Exception {
     RealMatrix firstLayer = new Array2DRowRealMatrix(new double[][]{{1d, 1d, 
2d, 3d}, {1d, 1d, 2d, 3d}, {1d, 1d, 2d, 3d}});
     RealMatrix secondLayer = new Array2DRowRealMatrix(new double[][]{{1d, 2d, 
3d}});
+
     RealMatrix[] RealMatrixes = new RealMatrix[]{firstLayer, secondLayer};
+
     NeuralNetwork neuralNetwork = createFFNN(RealMatrixes);
+
     Double prdictedValue = neuralNetwork.predict(createSample(5d, 6d, 7d));
     assertEquals(1l, Math.round(prdictedValue));
     assertEquals(Double.valueOf(0.9975273768433653d), prdictedValue);



---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@labs.apache.org
For additional commands, e-mail: commits-h...@labs.apache.org

Reply via email to