Author: tommaso
Date: Wed Apr 17 07:45:12 2013
New Revision: 1468786
URL: http://svn.apache.org/r1468786
Log:
Fixed the PredictionStrategy javadoc, refactored FeedForwardStrategy and added
a FeedForwardStrategy test
Added:
labs/yay/trunk/core/src/test/java/org/apache/yay/core/FeedForwardStrategyTest.java
(with props)
Modified:
labs/yay/trunk/api/src/main/java/org/apache/yay/PredictionStrategy.java
labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java
Modified:
labs/yay/trunk/api/src/main/java/org/apache/yay/PredictionStrategy.java
URL:
http://svn.apache.org/viewvc/labs/yay/trunk/api/src/main/java/org/apache/yay/PredictionStrategy.java?rev=1468786&r1=1468785&r2=1468786&view=diff
==============================================================================
--- labs/yay/trunk/api/src/main/java/org/apache/yay/PredictionStrategy.java
(original)
+++ labs/yay/trunk/api/src/main/java/org/apache/yay/PredictionStrategy.java Wed
Apr 17 07:45:12 2013
@@ -39,11 +39,11 @@ public interface PredictionStrategy<I, O
/**
* Perform a prediction on the given input values and weights settings
returning
- * an debug output.
+ * a debug output.
*
* @param inputs a collection of input values
* @param weightsMatrixSet the initial set of weights defined by an array of
matrix
- * @return the perturbed neural network state via its weights matrix array
+ * @return the perturbed neural network state via its activation values
*/
public RealMatrix[] debugOutput(Collection<I> inputs, RealMatrix[]
weightsMatrixSet);
Modified:
labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java
URL:
http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java?rev=1468786&r1=1468785&r2=1468786&view=diff
==============================================================================
---
labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java
(original)
+++
labs/yay/trunk/core/src/main/java/org/apache/yay/core/FeedForwardStrategy.java
Wed Apr 17 07:45:12 2013
@@ -49,27 +49,27 @@ public class FeedForwardStrategy impleme
}
@Override
- public Double[] predictOutput(Collection<Double> input, RealMatrix[]
RealMatrixSet) {
- RealMatrix[] realMatrixes = applyFF(input, RealMatrixSet);
+ public Double[] predictOutput(Collection<Double> input, RealMatrix[]
realMatrixSet) {
+ RealMatrix[] realMatrixes = applyFF(input, realMatrixSet);
RealMatrix x = realMatrixes[realMatrixes.length - 1];
double[] lastColumn = x.getColumn(x.getColumnDimension() - 1);
return ConversionUtils.toDoubleArray(lastColumn);
}
- public RealMatrix[] debugOutput(Collection<Double> input, RealMatrix[]
RealMatrixSet) {
- return applyFF(input, RealMatrixSet);
+ public RealMatrix[] debugOutput(Collection<Double> input, RealMatrix[]
realMatrixSet) {
+ return applyFF(input, realMatrixSet);
}
- private RealMatrix[] applyFF(Collection<Double> input, RealMatrix[]
RealMatrixSet) {
- RealMatrix[] debugOutput = new RealMatrix[RealMatrixSet.length];
+ private RealMatrix[] applyFF(Collection<Double> input, RealMatrix[]
realMatrixSet) {
+ RealMatrix[] debugOutput = new RealMatrix[realMatrixSet.length];
// TODO : fix this impl as it's very slow
RealVector v = ConversionUtils.toRealVector(input);
RealMatrix x = v.outerProduct(new ArrayRealVector(new
Double[]{1d})).transpose(); // a 1xN matrix
- for (int w = 0; w < RealMatrixSet.length; w++) {
- RealMatrix RealMatrix = RealMatrixSet[w];
+ for (int w = 0; w < realMatrixSet.length; w++) {
+ RealMatrix currentWeightsMatrix = realMatrixSet[w];
// compute matrix multiplication
- x = x.multiply(RealMatrix.transpose());
+ x = x.multiply(currentWeightsMatrix.transpose());
// apply the activation function to each element in the matrix
for (int i = 0; i < x.getRowDimension(); i++) {
@@ -78,7 +78,7 @@ public class FeedForwardStrategy impleme
for (int j = 0; j < doubles.length; j++) {
row.add(j, doubles[j]);
}
- CollectionUtils.transform(row, new HypothesisRowTransformer());
+ CollectionUtils.transform(row, new ActivationRowTransformer());
double[] finRow = new double[row.size()];
for (int h = 0; h < finRow.length; h++) {
finRow[h] = row.get(h);
@@ -90,7 +90,7 @@ public class FeedForwardStrategy impleme
return debugOutput;
}
- private class HypothesisRowTransformer implements Transformer {
+ private class ActivationRowTransformer implements Transformer {
@Override
public Object transform(Object input) {
assert input instanceof Double;
Added:
labs/yay/trunk/core/src/test/java/org/apache/yay/core/FeedForwardStrategyTest.java
URL:
http://svn.apache.org/viewvc/labs/yay/trunk/core/src/test/java/org/apache/yay/core/FeedForwardStrategyTest.java?rev=1468786&view=auto
==============================================================================
---
labs/yay/trunk/core/src/test/java/org/apache/yay/core/FeedForwardStrategyTest.java
(added)
+++
labs/yay/trunk/core/src/test/java/org/apache/yay/core/FeedForwardStrategyTest.java
Wed Apr 17 07:45:12 2013
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.yay.core;
+
+import java.util.Collection;
+import java.util.LinkedList;
+
+import org.apache.commons.math3.linear.Array2DRowRealMatrix;
+import org.apache.commons.math3.linear.RealMatrix;
+import org.junit.Test;
+
+import static junit.framework.Assert.assertEquals;
+import static junit.framework.Assert.assertNotNull;
+import static junit.framework.Assert.assertTrue;
+
+/**
+ * Testcase for {@link FeedForwardStrategy}
+ */
+public class FeedForwardStrategyTest {
+
+ @Test
+ public void testOutputDebugging() throws Exception {
+ FeedForwardStrategy feedForwardStrategy = new FeedForwardStrategy(new
SigmoidFunction());
+ RealMatrix[] weights = new RealMatrix[2];
+ weights[0] = new Array2DRowRealMatrix(new double[][]{{1d, 1d, 2d, 3d},
{1d, 1d, 2d, 3d}, {1d, 1d, 2d, 3d}});
+ weights[1] = new Array2DRowRealMatrix(new double[][]{{1d, 2d, 3d}});
+
+ Collection<Double> inputs = new LinkedList<Double>();
+ inputs.add(1d);
+ inputs.add(2d);
+ inputs.add(-5d);
+ inputs.add(7d);
+ RealMatrix[] activations = feedForwardStrategy.debugOutput(inputs,
weights);
+ assertNotNull(activations);
+ }
+}
Propchange:
labs/yay/trunk/core/src/test/java/org/apache/yay/core/FeedForwardStrategyTest.java
------------------------------------------------------------------------------
svn:eol-style = native
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]