Author: tommaso
Date: Fri Oct 11 12:12:55 2013
New Revision: 1531265
URL: http://svn.apache.org/r1531265
Log:
HAMA-809 - small refactoring of LRM
Modified:
hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/LogisticRegressionModel.java
Modified:
hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/LogisticRegressionModel.java
URL:
http://svn.apache.org/viewvc/hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/LogisticRegressionModel.java?rev=1531265&r1=1531264&r2=1531265&view=diff
==============================================================================
--- hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/LogisticRegressionModel.java (original)
+++ hama/trunk/ml/src/main/java/org/apache/hama/ml/regression/LogisticRegressionModel.java Fri Oct 11 12:12:55 2013
@@ -35,17 +35,14 @@ public class LogisticRegressionModel imp
costFunction = new CostFunction() {
@Override
public BigDecimal calculateCostForItem(DoubleVector x, double y, int m,
DoubleVector theta,
- HypothesisFunction hypothesis) {
+ HypothesisFunction hypothesis) {
// -1/m*(y*ln(hx) + (1-y)*ln(1-hx))
BigDecimal hx = applyHypothesisWithPrecision(theta, x);
- BigDecimal first = BigDecimal.valueOf(y).multiply(ln(hx));
- BigDecimal logarg = BigDecimal.valueOf(1).subtract(hx, DEFAULT_PRECISION);
- BigDecimal ln = ln(logarg);
- BigDecimal second = BigDecimal.valueOf(1d - y).multiply(ln);
- BigDecimal num = first.add(second);
- BigDecimal den = BigDecimal.valueOf(-1*m);
- BigDecimal res = num.divide(den, DEFAULT_PRECISION);
- return res;
+ BigDecimal firstTerm = BigDecimal.valueOf(y).multiply(ln(hx));
+ BigDecimal secondTerm = BigDecimal.valueOf(1d - y).multiply(ln(BigDecimal.valueOf(1).subtract(hx, DEFAULT_PRECISION)));
+ BigDecimal num = firstTerm.add(secondTerm);
+ BigDecimal den = BigDecimal.valueOf(-1 * m);
+ return num.divide(den, DEFAULT_PRECISION);
}
};
}
@@ -58,11 +55,9 @@ public class LogisticRegressionModel imp
private BigDecimal applyHypothesisWithPrecision(DoubleVector theta,
DoubleVector x) {
// 1 / (1 + (e^(-theta'x)))
double dotUnsafe = theta.multiply(-1d).dotUnsafe(x);
- double d = Math.exp(dotUnsafe);
- BigDecimal exp = BigDecimal.valueOf(d);
- BigDecimal den = BigDecimal.valueOf(1d).add(exp);
- BigDecimal remainder = BigDecimal.valueOf(1).subtract(den, DEFAULT_PRECISION);
+ BigDecimal den = BigDecimal.valueOf(1d).add(BigDecimal.valueOf(Math.exp(dotUnsafe)));
BigDecimal res = BigDecimal.valueOf(1).divide(den, DEFAULT_PRECISION);
+ BigDecimal remainder = BigDecimal.valueOf(1).subtract(den, DEFAULT_PRECISION);
if (res.doubleValue() == 1 && remainder.doubleValue() < 0) {
res = res.add(remainder);
}
@@ -70,23 +65,7 @@ public class LogisticRegressionModel imp
}
private BigDecimal ln(BigDecimal x) {
-// if (x.equals(BigDecimal.ONE)) {
-// return BigDecimal.ZERO;
-// }
-// x = x.subtract(BigDecimal.ONE);
-// int iterations = 10000000;
-// BigDecimal ret = new BigDecimal(iterations + 1);
-// for (long i = iterations; i >= 0; i--) {
-// BigDecimal N = new BigDecimal(i / 2 + 1).pow(2);
-// N = N.multiply(x, DEFAULT_PRECISION);
-// ret = N.divide(ret, DEFAULT_PRECISION);
-//
-// N = new BigDecimal(i + 1);
-// ret = ret.add(N, DEFAULT_PRECISION);
-//
-// }
-// ret = x.divide(ret, DEFAULT_PRECISION);
-// return ret;
+ // TODO : implement this using proper logarithm for BigDecimals
return BigDecimal.valueOf(Math.log(x.doubleValue()));
}