Author: jeastman
Date: Fri Dec 23 01:01:58 2011
New Revision: 1222524
URL: http://svn.apache.org/viewvc?rev=1222524&view=rev
Log:
MAHOUT-846: Cache pdf zProd2piR term constant over life of cluster
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/GaussianCluster.java
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/GaussianCluster.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/GaussianCluster.java?rev=1222524&r1=1222523&r2=1222524&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/GaussianCluster.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/GaussianCluster.java Fri Dec 23 01:01:58 2011
@@ -48,27 +48,47 @@ public class GaussianCluster extends Abs
return new GaussianCluster(getCenter(), getRadius(), getId());
}
+ // the value of the zProduct(S*2pi) term. Calculated below.
+ private Double zProd2piR = null;
+
@Override
public double pdf(VectorWritable vw) {
- Vector x = vw.get();
- Vector m = getCenter();
- Vector s = getRadius().plus(0.0000001); // add a small prior to avoid divide by zero
- return Math.exp(-(divideSquareAndSum(x.minus(m), s) / 2)) / zProdSqt2Pi(s);
+ if (zProd2piR == null) {
+ computeProd2piR();
+ }
+ return Math.exp(-(sumXminusCdivRsquared(vw.get()) / 2)) / zProd2piR;
}
- private double zProdSqt2Pi(Vector s) {
- double prod = 1;
- for (int i = 0; i < s.size(); i++) {
- prod *= s.getQuick(i) * UncommonDistributions.SQRT2PI;
+ /**
+ * Compute the product(r[i]*SQRT2PI) over all i. Note that the cluster Radius
+ * corresponds to the Stdev of a Gaussian and the Center to its Mean.
+ */
+ private void computeProd2piR() {
+ zProd2piR = 1.0;
+ for (Iterator<Element> it = getRadius().iterateNonZero(); it.hasNext();) {
+ Element radius = it.next();
+ zProd2piR *= radius.get() * UncommonDistributions.SQRT2PI;
}
- return prod;
}
- private double divideSquareAndSum(Vector numerator, Vector denominator) {
+ @Override
+ public void computeParameters() {
+ super.computeParameters();
+ zProd2piR = null;
+ }
+
+ /**
+ * @param x
+ * a Vector
+ * @return the zSum(((x[i]-c[i])/r[i])^2) over all i
+ */
+ private double sumXminusCdivRsquared(Vector x) {
double result = 0;
- for (Iterator<Element> it = denominator.iterateNonZero(); it.hasNext();) {
- Element denom = it.next();
- double quotient = numerator.getQuick(denom.index()) / denom.get();
+ for (Iterator<Element> it = getRadius().iterateNonZero(); it.hasNext();) {
+ Element radiusElem = it.next();
+ int index = radiusElem.index();
+ double quotient = (x.get(index) - getCenter().get(index))
+ / radiusElem.get();
result += quotient * quotient;
}
return result;