Author: tommaso
Date: Sat Oct  6 12:26:04 2012
New Revision: 1395024

URL: http://svn.apache.org/viewvc?rev=1395024&view=rev
Log:
[HAMA-651] - write theta and cost as job output: the master peer now emits (theta, cost) pairs as VectorWritable/DoubleWritable

Modified:
    hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/GradientDescentBSP.java

Modified: hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/GradientDescentBSP.java
URL: http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/GradientDescentBSP.java?rev=1395024&r1=1395023&r2=1395024&view=diff
==============================================================================
--- hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/GradientDescentBSP.java	(original)
+++ hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/GradientDescentBSP.java	Sat Oct  6 12:26:04 2012
@@ -35,7 +35,7 @@ import java.io.IOException;
  * A gradient descent (see 
<code>http://en.wikipedia.org/wiki/Gradient_descent</code>) BSP based abstract 
implementation.
  * Each extending class should implement the #hypothesis(DoubleVector theta, 
DoubleVector x) method for a specific
  */
-public abstract class GradientDescentBSP extends BSP<VectorWritable, 
DoubleWritable, NullWritable, NullWritable, VectorWritable> {
+public abstract class GradientDescentBSP extends BSP<VectorWritable, 
DoubleWritable, VectorWritable, DoubleWritable, VectorWritable> {
 
   private static final Logger log = 
LoggerFactory.getLogger(GradientDescentBSP.class);
   static final String INITIAL_THETA_VALUES = "initial.theta.values";
@@ -45,12 +45,12 @@ public abstract class GradientDescentBSP
   private DoubleVector theta;
 
   @Override
-  public void setup(BSPPeer<VectorWritable, DoubleWritable, NullWritable, 
NullWritable, VectorWritable> peer) throws IOException, SyncException, 
InterruptedException {
+  public void setup(BSPPeer<VectorWritable, DoubleWritable, VectorWritable, 
DoubleWritable, VectorWritable> peer) throws IOException, SyncException, 
InterruptedException {
     master = peer.getPeerIndex() == peer.getNumPeers() / 2;
   }
 
   @Override
-  public void bsp(BSPPeer<VectorWritable, DoubleWritable, NullWritable, 
NullWritable, VectorWritable> peer) throws IOException, SyncException, 
InterruptedException {
+  public void bsp(BSPPeer<VectorWritable, DoubleWritable, VectorWritable, 
DoubleWritable, VectorWritable> peer) throws IOException, SyncException, 
InterruptedException {
 
     while (true) {
 
@@ -137,6 +137,10 @@ public abstract class GradientDescentBSP
         if (log.isInfoEnabled()) {
           log.info("new theta for cost " + totalCost + " is " + 
theta.toArray().toString());
         }
+        // master writes down the output
+        if (master) {
+          peer.write(new VectorWritable(theta), new DoubleWritable(totalCost));
+        }
       }
       peer.sync();
 
@@ -159,7 +163,7 @@ public abstract class GradientDescentBSP
   public abstract double hypothesis(DoubleVector theta, DoubleVector x);
 
 
-  public void getTheta(BSPPeer<VectorWritable, DoubleWritable, NullWritable, 
NullWritable, VectorWritable> peer) throws IOException, SyncException, 
InterruptedException {
+  public void getTheta(BSPPeer<VectorWritable, DoubleWritable, VectorWritable, 
DoubleWritable, VectorWritable> peer) throws IOException, SyncException, 
InterruptedException {
     if (master && theta == null) {
       int size = getXSize(peer);
       theta = new DenseDoubleVector(size, 
peer.getConfiguration().getInt(INITIAL_THETA_VALUES, 10));
@@ -174,7 +178,7 @@ public abstract class GradientDescentBSP
     }
   }
 
-  private int getXSize(BSPPeer<VectorWritable, DoubleWritable, NullWritable, 
NullWritable, VectorWritable> peer) throws IOException {
+  private int getXSize(BSPPeer<VectorWritable, DoubleWritable, VectorWritable, 
DoubleWritable, VectorWritable> peer) throws IOException {
     VectorWritable key = null;
     peer.readNext(key, null);
     peer.reopenInput(); // reset input to start


Reply via email to