Author: tommaso
Date: Tue Jan 24 22:30:06 2012
New Revision: 1235536

URL: http://svn.apache.org/viewvc?rev=1235536&view=rev
Log:
starting implementing the feed forward stuff

Modified:
    labs/yay/trunk/core/pom.xml
    labs/yay/trunk/core/src/main/java/org/apache/yay/NeuralNetworkFactory.java

Modified: labs/yay/trunk/core/pom.xml
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/pom.xml?rev=1235536&r1=1235535&r2=1235536&view=diff
==============================================================================
--- labs/yay/trunk/core/pom.xml (original)
+++ labs/yay/trunk/core/pom.xml Tue Jan 24 22:30:06 2012
@@ -1,27 +1,32 @@
 <project
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd"
         xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance">
-  <modelVersion>4.0.0</modelVersion>
-  <artifactId>core</artifactId>
-  <version>0.1-SNAPSHOT</version>
-  <parent>
-    <groupId>org.apache.yay</groupId>
-    <artifactId>parent</artifactId>
+    <modelVersion>4.0.0</modelVersion>
+    <artifactId>core</artifactId>
     <version>0.1-SNAPSHOT</version>
-    <relativePath>../</relativePath>
-  </parent>
-  <name>Yay core</name>
-  <dependencies>
-    <dependency>
-      <groupId>org.mockito</groupId>
-      <artifactId>mockito-core</artifactId>
-      <version>1.9.0-rc1</version>
-      <scope>test</scope>
-    </dependency>
-    <dependency>
-      <groupId>org.apache.commons</groupId>
-      <artifactId>commons-math</artifactId>
-      <version>2.2</version>
-    </dependency>
-  </dependencies>
+    <parent>
+        <groupId>org.apache.yay</groupId>
+        <artifactId>parent</artifactId>
+        <version>0.1-SNAPSHOT</version>
+        <relativePath>../</relativePath>
+    </parent>
+    <name>Yay core</name>
+    <dependencies>
+        <dependency>
+            <groupId>org.mockito</groupId>
+            <artifactId>mockito-core</artifactId>
+            <version>1.9.0-rc1</version>
+            <scope>test</scope>
+        </dependency>
+        <dependency>
+            <groupId>org.apache.commons</groupId>
+            <artifactId>commons-math</artifactId>
+            <version>2.2</version>
+        </dependency>
+        <dependency>
+            <groupId>commons-collections</groupId>
+            <artifactId>commons-collections</artifactId>
+            <version>3.2.1</version>
+        </dependency>
+    </dependencies>
 </project>
\ No newline at end of file

Modified: labs/yay/trunk/core/src/main/java/org/apache/yay/NeuralNetworkFactory.java
URL: http://svn.apache.org/viewvc/labs/yay/trunk/core/src/main/java/org/apache/yay/NeuralNetworkFactory.java?rev=1235536&r1=1235535&r2=1235536&view=diff
==============================================================================
--- labs/yay/trunk/core/src/main/java/org/apache/yay/NeuralNetworkFactory.java (original)
+++ labs/yay/trunk/core/src/main/java/org/apache/yay/NeuralNetworkFactory.java Tue Jan 24 22:30:06 2012
@@ -18,9 +18,12 @@
  */
 package org.apache.yay;
 
+import org.apache.commons.collections.CollectionUtils;
+import org.apache.commons.collections.Transformer;
 import org.apache.commons.math.linear.RealMatrix;
 import org.apache.yay.bio.Signal;
 
+import java.util.ArrayList;
 import java.util.Collection;
 import java.util.Set;
 
@@ -42,24 +45,61 @@ public class NeuralNetworkFactory {
       @Override
       public Signal predict(Signal... input) {
         /* Octave code to be converted :
-        * m = size(X, 1);
-        num_labels = size(Theta2, 1);
+       * m = size(X, 1);
+       num_labels = size(Theta2, 1);

-        p = zeros(size(X, 1), 1);
+       p = zeros(size(X, 1), 1);

-        h1 = sigmoid([ones(m, 1) X] * Theta1');
-        h2 = sigmoid([ones(m, 1) h1] * Theta2');
-        [dummy, p] = max(h2, [], 2);
+       h1 = sigmoid([ones(m, 1) X] * Theta1');
+       h2 = sigmoid([ones(m, 1) h1] * Theta2');
+       [dummy, p] = max(h2, [], 2);

-         */
+        */

+        // TODO : fix this impl as it's very slow and commons-math Java1.4 compatibility introduces more complexity
+        // NOTE(review): the Octave reference prepends a bias column ([ones(m, 1) X]) before each
+        // product; no bias column is added below -- confirm what weightsMatrixes expect.
+
+        final SigmoidFunction sigmoidFunction = new SigmoidFunction();
+        // the transformer does not depend on the current row, so build it once instead of per row
+        final Transformer sigmoidTransformer = new Transformer() {
+          @Override
+          public Object transform(Object element) {
+            assert element instanceof Double;
+            final Double d = (Double) element;
+            return sigmoidFunction.apply(new Signal<Double>() {
+              @Override
+              public Double getValue() {
+                return d;
+              }
+            });
+          }
+        };
         RealMatrix x = matrixConverter.toMatrix(trainingSet);
         for (RealMatrix weightsMatrix : weightsMatrixes) {
+          // compute matrix multiplication
           x = weightsMatrix.transpose().multiply(x);
-          // TODO : apply SigmoidFunction
+          // apply SigmoidFunction element-wise, one row at a time
+          for (int i = 0; i < x.getRowDimension(); i++) {
+            double[] doubles = x.getRow(i);
+            // presize and append: ArrayList#set on an empty list throws IndexOutOfBoundsException
+            ArrayList<Double> row = new ArrayList<Double>(doubles.length);
+            for (double value : doubles) {
+              row.add(value);
+            }
+            // CollectionUtils.transform replaces each list element in place with the transformer's result
+            CollectionUtils.transform(row, sigmoidTransformer);
+            double[] finRow = new double[row.size()];
+            for (int h = 0; h < finRow.length; h++) {
+              // NOTE(review): unboxing assumes SigmoidFunction.apply returns a Double -- confirm
+              finRow[h] = row.get(h);
+            }
+            x.setRow(i, finRow);
+          }
         }

-
+        // NOTE(review): `input` is never read, and the activations computed in x are discarded here
+        // TODO : build the prediction Signal from x instead of returning null
         return null;
       }
 



---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to