Author: tommaso
Date: Fri Oct 12 11:46:34 2012
New Revision: 1397520
URL: http://svn.apache.org/viewvc?rev=1397520&view=rev
Log:
[HAMA-651] - Adjust the default threshold and alpha values; fix input reopening,
theta initialization and partial-derivative aggregation
Modified:
hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/GradientDescentBSP.java
Modified:
hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/GradientDescentBSP.java
URL:
http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/GradientDescentBSP.java?rev=1397520&r1=1397519&r2=1397520&view=diff
==============================================================================
---
hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/GradientDescentBSP.java
(original)
+++
hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/GradientDescentBSP.java
Fri Oct 12 11:46:34 2012
@@ -52,8 +52,8 @@ public class GradientDescentBSP extends
public void setup(BSPPeer<VectorWritable, DoubleWritable, VectorWritable,
DoubleWritable, VectorWritable> peer) throws IOException, SyncException,
InterruptedException {
master = peer.getPeerIndex() == peer.getNumPeers() / 2;
cost = Integer.MAX_VALUE;
- threshold = peer.getConfiguration().getFloat(THRESHOLD, 0.01f);
- alpha = peer.getConfiguration().getFloat(ALPHA, 0.3f);
+ threshold = peer.getConfiguration().getFloat(THRESHOLD, 0.1f);
+ alpha = peer.getConfiguration().getFloat(ALPHA, 0.003f);
try {
regressionModel = ((Class<? extends RegressionModel>)
peer.getConfiguration().getClass(REGRESSION_MODEL_CLASS,
LinearRegressionModel.class)).newInstance();
} catch (Exception e) {
@@ -77,7 +77,6 @@ public class GradientDescentBSP extends
// read an input
KeyValuePair<VectorWritable, DoubleWritable> kvp;
while ((kvp = peer.readNext()) != null) {
-
// calculate cost for given input
double y = kvp.getValue().get();
DoubleVector x = kvp.getKey().getVector();
@@ -89,15 +88,15 @@ public class GradientDescentBSP extends
}
// cost is sent and aggregated by each
- double totalCost = localCost;
-
for (String peerName : peer.getAllPeerNames()) {
- peer.send(peerName, new VectorWritable(new DenseDoubleVector(new
double[]{localCost, numRead})));
+ if (!peerName.equals(peer.getPeerName())) { // avoid sending to oneself
+ peer.send(peerName, new VectorWritable(new DenseDoubleVector(new
double[]{localCost, numRead})));
+ }
}
peer.sync();
// second superstep : aggregate cost calculation
-
+ double totalCost = localCost;
VectorWritable costResult;
while ((costResult = peer.getCurrentMessage()) != null) {
totalCost += costResult.getVector().get(0);
@@ -106,27 +105,23 @@ public class GradientDescentBSP extends
totalCost /= numRead; // TODO : remove this and incorporate the 1/m
element in RegressionModel#calculateCostForItem
+ // cost check
if (cost - totalCost < 0) {
throw new RuntimeException("gradient descent failed to converge with
alpha " + alpha);
- } else if (totalCost == 0 || cost - totalCost < threshold) {
+ } else if (totalCost == 0 || totalCost < threshold) {
+ log.info(peer.getPeerName()+": finishing!");
cost = totalCost;
break;
} else {
cost = totalCost;
+ if (log.isInfoEnabled()) {
+ log.info(peer.getPeerName()+": cost is " + cost);
+ }
}
-
- if (log.isInfoEnabled()) {
- log.info("cost is " + cost);
- }
-
-
+ peer.reopenInput();
peer.sync();
- if (master) { // TODO : check if this has to be done only by the master
- peer.reopenInput();
- }
-
double[] thetaDelta = new double[theta.getLength()];
// third superstep : calculate partial derivatives' deltas in parallel
@@ -148,8 +143,8 @@ public class GradientDescentBSP extends
// fourth superstep : aggregate partial derivatives
VectorWritable thetaDeltaSlice;
+ double[] newTheta = thetaDelta;
while ((thetaDeltaSlice = peer.getCurrentMessage()) != null) {
- double[] newTheta = new double[theta.getLength()];
for (int j = 0; j < theta.getLength(); j++) {
newTheta[j] += thetaDeltaSlice.getVector().get(j);
@@ -158,17 +153,18 @@ public class GradientDescentBSP extends
for (int j = 0; j < theta.getLength(); j++) {
newTheta[j] = theta.get(j) - newTheta[j] * alpha;
}
+ }
+ theta = new DenseDoubleVector(newTheta);
- theta = new DenseDoubleVector(newTheta);
-
- if (log.isInfoEnabled()) {
- log.info("new theta for cost " + cost + " is " +
theta.toArray().toString());
- }
- // master writes down the output
- if (master) {
- peer.write(new VectorWritable(theta), new DoubleWritable(cost));
- }
+ if (log.isInfoEnabled()) {
+ log.info(peer.getPeerName()+": new theta for cost " + cost + " is " +
theta.toString());
+ }
+ // master writes down the output
+ if (master) {
+ peer.write(new VectorWritable(theta), new DoubleWritable(cost));
}
+
+ peer.reopenInput();
peer.sync();
}
@@ -178,7 +174,7 @@ public class GradientDescentBSP extends
@Override
public void cleanup(BSPPeer<VectorWritable, DoubleWritable, VectorWritable,
DoubleWritable, VectorWritable> peer) throws IOException {
if (log.isInfoEnabled()) {
- log.info("computation finished with cost " + cost + " for theta " +
theta);
+ log.info(peer.getPeerName()+":computation finished with cost " + cost
+ " for theta " + theta);
}
// master writes down the final output
if (master) {
@@ -187,23 +183,28 @@ public class GradientDescentBSP extends
}
public void getTheta(BSPPeer<VectorWritable, DoubleWritable, VectorWritable,
DoubleWritable, VectorWritable> peer) throws IOException, SyncException,
InterruptedException {
- if (master && theta == null) {
- int size = getXSize(peer);
- theta = new DenseDoubleVector(size,
peer.getConfiguration().getInt(INITIAL_THETA_VALUES, 10));
- for (String peerName : peer.getAllPeerNames()) {
- peer.send(peerName, new VectorWritable(theta));
- }
- peer.sync();
- } else {
- peer.sync();
- VectorWritable vectorWritable = peer.getCurrentMessage();
- theta = vectorWritable.getVector();
+ if (theta == null) {
+ if (master) {
+ int size = getXSize(peer);
+ theta = new DenseDoubleVector(size,
peer.getConfiguration().getInt(INITIAL_THETA_VALUES, 10));
+ for (String peerName : peer.getAllPeerNames()) {
+ peer.send(peerName, new VectorWritable(theta));
+ }
+ log.info(peer.getPeerName() + ": sending theta");
+ peer.sync();
+ } else {
+ log.info(peer.getPeerName() + ": getting theta");
+ peer.sync();
+ VectorWritable vectorWritable = peer.getCurrentMessage();
+ theta = vectorWritable.getVector();
+ }
}
}
private int getXSize(BSPPeer<VectorWritable, DoubleWritable, VectorWritable,
DoubleWritable, VectorWritable> peer) throws IOException {
- VectorWritable key = null;
- peer.readNext(key, null);
+ VectorWritable key = new VectorWritable();
+ DoubleWritable value = new DoubleWritable();
+ peer.readNext(key, value);
peer.reopenInput(); // reset input to start
if (key == null) {
throw new IOException("cannot read input vector size");