Author: tommaso
Date: Sun Oct 28 07:14:27 2012
New Revision: 1402944
URL: http://svn.apache.org/viewvc?rev=1402944&view=rev
Log:
[HAMA-660] - the total number of items to read is precalculated and passed to the
CostFunction
Modified:
hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/CostFunction.java
hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/GradientDescentBSP.java
hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/LinearRegressionModel.java
hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/LogisticRegressionModel.java
hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/RegressionModel.java
Modified:
hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/CostFunction.java
URL:
http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/CostFunction.java?rev=1402944&r1=1402943&r2=1402944&view=diff
==============================================================================
--- hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/CostFunction.java
(original)
+++ hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/CostFunction.java
Sun Oct 28 07:14:27 2012
@@ -30,10 +30,11 @@ public interface CostFunction {
*
* @param x the input vector
* @param y the learned output for x
+ * @param m the number of existing items
* @param theta the parameters vector theta
* @param hypothesis the hypothesis function to model the problem
* @return the calculated cost for input x and output y
*/
- public double calculateCostForItem(DoubleVector x, double y, DoubleVector
theta, HypothesisFunction hypothesis);
+ public double calculateCostForItem(DoubleVector x, double y, int m,
DoubleVector theta, HypothesisFunction hypothesis);
}
Modified:
hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/GradientDescentBSP.java
URL:
http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/GradientDescentBSP.java?rev=1402944&r1=1402943&r2=1402944&view=diff
==============================================================================
---
hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/GradientDescentBSP.java
(original)
+++
hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/GradientDescentBSP.java
Sun Oct 28 07:14:27 2012
@@ -49,6 +49,7 @@ public class GradientDescentBSP extends
private float alpha;
private RegressionModel regressionModel;
private int iterationsThreshold;
+ private int m;
@Override
public void setup(BSPPeer<VectorWritable, DoubleWritable, VectorWritable,
DoubleWritable, VectorWritable> peer) throws IOException, SyncException,
InterruptedException {
@@ -66,6 +67,30 @@ public class GradientDescentBSP extends
@Override
public void bsp(BSPPeer<VectorWritable, DoubleWritable, VectorWritable,
DoubleWritable, VectorWritable> peer) throws IOException, SyncException,
InterruptedException {
+ // 0 superstep : count items
+
+ int itemCount = 0;
+ while (peer.readNext() != null) {
+ // increment counter
+ itemCount++;
+ }
+ for (String peerName : peer.getAllPeerNames()) {
+ if (!peerName.equals(peer.getPeerName())) { // avoid sending to oneself
+ peer.send(peerName, new VectorWritable(new DenseDoubleVector(new
double[]{itemCount})));
+ }
+ }
+ peer.sync();
+
+ // aggregate number of items
+ VectorWritable itemsResult;
+ while ((itemsResult = peer.getCurrentMessage()) != null) {
+ itemCount += itemsResult.getVector().get(0);
+ }
+
+ m = itemCount;
+
+ peer.reopenInput();
+
int iterations = 0;
while (true) {
@@ -75,25 +100,22 @@ public class GradientDescentBSP extends
double localCost = 0d;
- int numRead = 0;
-
- // read an input
+ // read an item
KeyValuePair<VectorWritable, DoubleWritable> kvp;
while ((kvp = peer.readNext()) != null) {
// calculate cost for given input
double y = kvp.getValue().get();
DoubleVector x = kvp.getKey().getVector();
- double costForX = regressionModel.calculateCostForItem(x, y, theta);
+ double costForX = regressionModel.calculateCostForItem(x, y, m, theta);
// adds to local cost
localCost += costForX;
- numRead++;
}
// cost is sent and aggregated by each
for (String peerName : peer.getAllPeerNames()) {
if (!peerName.equals(peer.getPeerName())) { // avoid sending to oneself
- peer.send(peerName, new VectorWritable(new DenseDoubleVector(new
double[]{localCost, numRead})));
+ peer.send(peerName, new VectorWritable(new DenseDoubleVector(new
double[]{localCost})));
}
}
peer.sync();
@@ -103,11 +125,8 @@ public class GradientDescentBSP extends
VectorWritable costResult;
while ((costResult = peer.getCurrentMessage()) != null) {
totalCost += costResult.getVector().get(0);
- numRead += costResult.getVector().get(1);
}
- totalCost /= numRead; // TODO : remove this and incorporate the 1/m
element in RegressionModel#calculateCostForItem
-
// cost check
if (cost - totalCost < 0) {
throw new RuntimeException(new StringBuilder("gradient descent failed
to converge with alpha ").
Modified:
hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/LinearRegressionModel.java
URL:
http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/LinearRegressionModel.java?rev=1402944&r1=1402943&r2=1402944&view=diff
==============================================================================
---
hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/LinearRegressionModel.java
(original)
+++
hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/LinearRegressionModel.java
Sun Oct 28 07:14:27 2012
@@ -29,8 +29,8 @@ public class LinearRegressionModel imple
public LinearRegressionModel() {
costFunction = new CostFunction() {
@Override
- public double calculateCostForItem(DoubleVector x, double y,
DoubleVector theta, HypothesisFunction hypothesis) {
- return y * Math.pow(applyHypothesis(theta, x) - y, 2) / 2;
+ public double calculateCostForItem(DoubleVector x, double y, int m,
DoubleVector theta, HypothesisFunction hypothesis) {
+ return y * Math.pow(applyHypothesis(theta, x) - y, 2) / (2 * m);
}
};
}
@@ -41,7 +41,7 @@ public class LinearRegressionModel imple
}
@Override
- public double calculateCostForItem(DoubleVector x, double y, DoubleVector
theta) {
- return costFunction.calculateCostForItem(x, y, theta, this);
+ public double calculateCostForItem(DoubleVector x, double y, int m,
DoubleVector theta) {
+ return costFunction.calculateCostForItem(x, y, m, theta, this);
}
}
Modified:
hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/LogisticRegressionModel.java
URL:
http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/LogisticRegressionModel.java?rev=1402944&r1=1402943&r2=1402944&view=diff
==============================================================================
---
hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/LogisticRegressionModel.java
(original)
+++
hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/LogisticRegressionModel.java
Sun Oct 28 07:14:27 2012
@@ -29,8 +29,8 @@ public class LogisticRegressionModel imp
public LogisticRegressionModel() {
costFunction = new CostFunction() {
@Override
- public double calculateCostForItem(DoubleVector x, double y,
DoubleVector theta, HypothesisFunction hypothesis) {
- return -1 * y * Math.log(applyHypothesis(theta, x)) + (1 - y) *
Math.log(1 - applyHypothesis(theta, x));
+ public double calculateCostForItem(DoubleVector x, double y, int m,
DoubleVector theta, HypothesisFunction hypothesis) {
+ return (-1 * y * Math.log(applyHypothesis(theta, x)) + (1 - y) *
Math.log(1 - applyHypothesis(theta, x))) / m;
}
};
}
@@ -41,7 +41,7 @@ public class LogisticRegressionModel imp
}
@Override
- public double calculateCostForItem(DoubleVector x, double y, DoubleVector
theta) {
- return costFunction.calculateCostForItem(x, y, theta, this);
+ public double calculateCostForItem(DoubleVector x, double y, int m,
DoubleVector theta) {
+ return costFunction.calculateCostForItem(x, y, m, theta, this);
}
}
Modified:
hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/RegressionModel.java
URL:
http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/RegressionModel.java?rev=1402944&r1=1402943&r2=1402944&view=diff
==============================================================================
---
hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/RegressionModel.java
(original)
+++
hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/RegressionModel.java
Sun Oct 28 07:14:27 2012
@@ -30,9 +30,10 @@ public interface RegressionModel extends
*
* @param x the input vector
* @param y the learned output for x
+ * @param m the total number of existing items
* @param theta the parameters vector theta
* @return the calculated cost for input x and output y
*/
- public double calculateCostForItem(DoubleVector x, double y, DoubleVector
theta);
+ public double calculateCostForItem(DoubleVector x, double y, int m,
DoubleVector theta);
}