Author: tommaso
Date: Tue Jan 24 22:30:06 2012
New Revision: 1235536
URL: http://svn.apache.org/viewvc?rev=1235536&view=rev
Log:
Start implementing feed-forward prediction.
Modified:
labs/yay/trunk/core/pom.xml
labs/yay/trunk/core/src/main/java/org/apache/yay/NeuralNetworkFactory.java
Modified: labs/yay/trunk/core/pom.xml
URL:
http://svn.apache.org/viewvc/labs/yay/trunk/core/pom.xml?rev=1235536&r1=1235535&r2=1235536&view=diff
==============================================================================
--- labs/yay/trunk/core/pom.xml (original)
+++ labs/yay/trunk/core/pom.xml Tue Jan 24 22:30:06 2012
@@ -1,27 +1,32 @@
<project
    xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd"
    xmlns="http://maven.apache.org/POM/4.0.0"
    xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance">
  <modelVersion>4.0.0</modelVersion>
  <artifactId>core</artifactId>
  <version>0.1-SNAPSHOT</version>
  <parent>
    <groupId>org.apache.yay</groupId>
    <artifactId>parent</artifactId>
    <version>0.1-SNAPSHOT</version>
    <relativePath>../</relativePath>
  </parent>
  <name>Yay core</name>
  <dependencies>
    <dependency>
      <groupId>org.mockito</groupId>
      <artifactId>mockito-core</artifactId>
      <version>1.9.0-rc1</version>
      <scope>test</scope>
    </dependency>
    <dependency>
      <groupId>org.apache.commons</groupId>
      <artifactId>commons-math</artifactId>
      <version>2.2</version>
    </dependency>
    <!-- 3.2.2 rather than 3.2.1: 3.2.1 ships the InvokerTransformer deserialization
         gadget chain (CVE-2015-7501); 3.2.2 is the API-compatible fixed release. -->
    <dependency>
      <groupId>commons-collections</groupId>
      <artifactId>commons-collections</artifactId>
      <version>3.2.2</version>
    </dependency>
  </dependencies>
</project>
\ No newline at end of file
Modified:
labs/yay/trunk/core/src/main/java/org/apache/yay/NeuralNetworkFactory.java
URL:
http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/NeuralNetworkFactory.java?rev=1235536&r1=1235535&r2=1235536&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/NeuralNetworkFactory.java
(original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/NeuralNetworkFactory.java
Tue Jan 24 22:30:06 2012
@@ -18,9 +18,12 @@
*/
package org.apache.yay;
+import org.apache.commons.collections.CollectionUtils;
+import org.apache.commons.collections.Transformer;
import org.apache.commons.math.linear.RealMatrix;
import org.apache.yay.bio.Signal;
+import java.util.ArrayList;
import java.util.Collection;
import java.util.Set;
@@ -42,24 +45,52 @@ public class NeuralNetworkFactory {
@Override
public Signal predict(Signal... input) {
/* Octave code to be converted :
- * m = size(X, 1);
- num_labels = size(Theta2, 1);
+ * m = size(X, 1);
+ num_labels = size(Theta2, 1);
- p = zeros(size(X, 1), 1);
+ p = zeros(size(X, 1), 1);
- h1 = sigmoid([ones(m, 1) X] * Theta1');
- h2 = sigmoid([ones(m, 1) h1] * Theta2');
- [dummy, p] = max(h2, [], 2);
+ h1 = sigmoid([ones(m, 1) X] * Theta1');
+ h2 = sigmoid([ones(m, 1) h1] * Theta2');
+ [dummy, p] = max(h2, [], 2);
- */
+ */
+ // TODO : fix this impl as it's very slow and commons-math Java1.4
compatibility introduces more complexity
+
+ final SigmoidFunction sigmoidFunction = new SigmoidFunction();
RealMatrix x = matrixConverter.toMatrix(trainingSet);
for (RealMatrix weightsMatrix : weightsMatrixes) {
+ // compute matrix multiplication
x = weightsMatrix.transpose().multiply(x);
- // TODO : apply SigmoidFunction
+ // apply SigmoidFunction
+ for (int i = 0; i < x.getRowDimension(); i++) {
+ double[] doubles = x.getRow(i);
+ ArrayList<Double> row = new ArrayList<Double>();
+ for (int j = 0; j < doubles.length; j++) {
+ row.set(j, doubles[j]);
+ }
+ CollectionUtils.transform(row, new Transformer() {
+ @Override
+ public Object transform(Object input) {
+ assert input instanceof Double;
+ final Double d = (Double) input;
+ return sigmoidFunction.apply(new Signal<Double>() {
+ @Override
+ public Double getValue() {
+ return d;
+ }
+ });
+ }
+ });
+ double[] finRow = new double[row.size()];
+ for (int h = 0; h < finRow.length; h++) {
+ finRow[h] = row.get(h);
+ }
+ x.setRow(i, finRow);
+ }
}
-
return null;
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]