Author: tommaso
Date: Sat Oct 6 12:26:04 2012
New Revision: 1395024
URL: http://svn.apache.org/viewvc?rev=1395024&view=rev
Log:
[HAMA-651] - collect cost and theta as job output
Modified:
hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/GradientDescentBSP.java
Modified:
hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/GradientDescentBSP.java
URL:
http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/GradientDescentBSP.java?rev=1395024&r1=1395023&r2=1395024&view=diff
==============================================================================
---
hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/GradientDescentBSP.java
(original)
+++
hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/GradientDescentBSP.java
Sat Oct 6 12:26:04 2012
@@ -35,7 +35,7 @@ import java.io.IOException;
* A gradient descent (see
<code>http://en.wikipedia.org/wiki/Gradient_descent</code>) BSP based abstract
implementation.
* Each extending class should implement the #hypothesis(DoubleVector theta,
DoubleVector x) method for a specific
*/
-public abstract class GradientDescentBSP extends BSP<VectorWritable,
DoubleWritable, NullWritable, NullWritable, VectorWritable> {
+public abstract class GradientDescentBSP extends BSP<VectorWritable,
DoubleWritable, VectorWritable, DoubleWritable, VectorWritable> {
private static final Logger log =
LoggerFactory.getLogger(GradientDescentBSP.class);
static final String INITIAL_THETA_VALUES = "initial.theta.values";
@@ -45,12 +45,12 @@ public abstract class GradientDescentBSP
private DoubleVector theta;
@Override
- public void setup(BSPPeer<VectorWritable, DoubleWritable, NullWritable,
NullWritable, VectorWritable> peer) throws IOException, SyncException,
InterruptedException {
+ public void setup(BSPPeer<VectorWritable, DoubleWritable, VectorWritable,
DoubleWritable, VectorWritable> peer) throws IOException, SyncException,
InterruptedException {
master = peer.getPeerIndex() == peer.getNumPeers() / 2;
}
@Override
- public void bsp(BSPPeer<VectorWritable, DoubleWritable, NullWritable,
NullWritable, VectorWritable> peer) throws IOException, SyncException,
InterruptedException {
+ public void bsp(BSPPeer<VectorWritable, DoubleWritable, VectorWritable,
DoubleWritable, VectorWritable> peer) throws IOException, SyncException,
InterruptedException {
while (true) {
@@ -137,6 +137,10 @@ public abstract class GradientDescentBSP
if (log.isInfoEnabled()) {
log.info("new theta for cost " + totalCost + " is " +
theta.toArray().toString());
}
+ // master writes down the output
+ if (master) {
+ peer.write(new VectorWritable(theta), new DoubleWritable(totalCost));
+ }
}
peer.sync();
@@ -159,7 +163,7 @@ public abstract class GradientDescentBSP
public abstract double hypothesis(DoubleVector theta, DoubleVector x);
- public void getTheta(BSPPeer<VectorWritable, DoubleWritable, NullWritable,
NullWritable, VectorWritable> peer) throws IOException, SyncException,
InterruptedException {
+ public void getTheta(BSPPeer<VectorWritable, DoubleWritable, VectorWritable,
DoubleWritable, VectorWritable> peer) throws IOException, SyncException,
InterruptedException {
if (master && theta == null) {
int size = getXSize(peer);
theta = new DenseDoubleVector(size,
peer.getConfiguration().getInt(INITIAL_THETA_VALUES, 10));
@@ -174,7 +178,7 @@ public abstract class GradientDescentBSP
}
}
- private int getXSize(BSPPeer<VectorWritable, DoubleWritable, NullWritable,
NullWritable, VectorWritable> peer) throws IOException {
+ private int getXSize(BSPPeer<VectorWritable, DoubleWritable, VectorWritable,
DoubleWritable, VectorWritable> peer) throws IOException {
VectorWritable key = null;
peer.readNext(key, null);
peer.reopenInput(); // reset input to start