Author: tommaso
Date: Sat Oct 6 19:51:17 2012
New Revision: 1395161
URL: http://svn.apache.org/viewvc?rev=1395161&view=rev
Log:
[HAMA-651] - added calculateCostForItem method since the cost function also depends on the specific algorithm used
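
For illustration, here is a minimal sketch of an extending class supplying the two abstract methods. The subclass name LogisticRegressionBSP is hypothetical and not part of this commit; it reuses the logistic cost that was previously inlined in the BSP, negated so that the plain average now computed in the aggregation superstep yields the usual cross-entropy cost, and it assumes DoubleVector exposes a dot(DoubleVector) product as in the Hama ml math package.

    import org.apache.hama.ml.math.DoubleVector;

    // hypothetical example, not part of this commit
    public class LogisticRegressionBSP extends GradientDescentBSP {

      // sigmoid hypothesis: h(theta, x) = 1 / (1 + e^(-theta . x))
      @Override
      public double applyHypothesis(DoubleVector theta, DoubleVector x) {
        return 1d / (1d + Math.exp(-theta.dot(x)));
      }

      // per-item negative log-likelihood; the second superstep averages these
      @Override
      protected double calculateCostForItem(double y, DoubleVector x,
          DoubleVector theta) {
        double h = applyHypothesis(theta, x);
        return -(y * Math.log(h) + (1 - y) * Math.log(1 - h));
      }
    }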
Modified:
hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/GradientDescentBSP.java
Modified: hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/GradientDescentBSP.java
URL: http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/GradientDescentBSP.java?rev=1395161&r1=1395160&r2=1395161&view=diff
==============================================================================
--- hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/GradientDescentBSP.java (original)
+++ hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/GradientDescentBSP.java Sat Oct 6 19:51:17 2012
@@ -18,7 +18,6 @@
package org.apache.hama.ml.regression;
import org.apache.hadoop.io.DoubleWritable;
-import org.apache.hadoop.io.NullWritable;
import org.apache.hama.bsp.BSP;
import org.apache.hama.bsp.BSPPeer;
import org.apache.hama.bsp.sync.SyncException;
@@ -33,7 +32,7 @@ import java.io.IOException;
/**
* A gradient descent (see <code>http://en.wikipedia.org/wiki/Gradient_descent</code>) BSP based abstract implementation.
- * Each extending class should implement the #hypothesis(DoubleVector theta, DoubleVector x) method for a specific
+ * Each extending class should implement the #applyHypothesis(DoubleVector theta, DoubleVector x) method for a specific
*/
public abstract class GradientDescentBSP extends BSP<VectorWritable, DoubleWritable, VectorWritable, DoubleWritable, VectorWritable> {
@@ -69,7 +68,7 @@ public abstract class GradientDescentBSP
// calculate cost for given input
double y = kvp.getValue().get();
DoubleVector x = kvp.getKey().getVector();
- double costForX = y * Math.log(hypothesis(theta, x)) + (1 - y) * Math.log(1 - hypothesis(theta, x));
+ double costForX = calculateCostForItem(y, x, theta);
// adds to local cost
localCost += costForX;
@@ -84,7 +83,7 @@ public abstract class GradientDescentBSP
}
peer.sync();
- // second superstep : cost calculation
+ // second superstep : aggregate cost calculation
VectorWritable costResult;
while ((costResult = peer.getCurrentMessage()) != null) {
@@ -92,7 +91,8 @@ public abstract class GradientDescentBSP
numRead += costResult.getVector().get(1);
}
- totalCost = totalCost * (-1 / numRead);
+ totalCost /= numRead;
+
if (log.isInfoEnabled()) {
log.info("cost is " + totalCost);
}
@@ -103,11 +103,11 @@ public abstract class GradientDescentBSP
double[] thetaDelta = new double[theta.getLength()];
- // second superstep : calculate partial derivatives in parallel
+ // third superstep : calculate partial derivatives' deltas in parallel
while ((kvp = peer.readNext()) != null) {
DoubleVector x = kvp.getKey().getVector();
double y = kvp.getValue().get();
- double difference = hypothesis(theta, x) - y;
+ double difference = applyHypothesis(theta, x) - y;
for (int j = 0; j < theta.getLength(); j++) {
thetaDelta[j] += difference * x.get(j);
}
@@ -120,6 +120,7 @@ public abstract class GradientDescentBSP
peer.sync();
+ // fourth superstep : aggregate partial derivatives
VectorWritable thetaDeltaSlice;
while ((thetaDeltaSlice = peer.getCurrentMessage()) != null) {
double[] newTheta = new double[theta.getLength()];
@@ -154,13 +155,22 @@ public abstract class GradientDescentBSP
}
/**
- * Applies the hypothesis given a set of parameters theta to a given input x
+ * Calculates the cost function for a given item (input x, output y)
+ * @param y the learned output for x
+ * @param x the input vector
+ * @param theta the parameters vector theta
+ * @return the calculated cost for input x and output y
+ */
+ protected abstract double calculateCostForItem(double y, DoubleVector x, DoubleVector theta);
+
+ /**
+ * Applies the hypothesis given a set of parameters theta to a given input x
*
* @param theta the parameters vector
* @param x the input
* @return a <code>double</code> number
*/
- public abstract double hypothesis(DoubleVector theta, DoubleVector x);
+ public abstract double applyHypothesis(DoubleVector theta, DoubleVector x);
public void getTheta(BSPPeer<VectorWritable, DoubleWritable, VectorWritable, DoubleWritable, VectorWritable> peer) throws IOException, SyncException, InterruptedException {
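
A note on the sign convention implied by this change: the removed aggregation line computed totalCost * (-1 / numRead), applying the negation once, globally, to the summed log-likelihood, while the new line only averages (totalCost /= numRead). Each calculateCostForItem implementation is therefore presumably expected to return an already signed per-item cost (e.g. a negative log-likelihood), as in the sketch above.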