Author: srowen
Date: Tue Jan 18 16:28:08 2011
New Revision: 1060450
URL: http://svn.apache.org/viewvc?rev=1060450&view=rev
Log:
MAHOUT-533 make stdev calculation more accurate in clustering
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/GaussianAccumulator.java
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/OnlineGaussianAccumulator.java
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/RunningSumsGaussianAccumulator.java
mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestGaussianAccumulators.java
mahout/trunk/utils/src/main/java/org/apache/mahout/clustering/cdbw/CDbwEvaluator.java
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/GaussianAccumulator.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/GaussianAccumulator.java?rev=1060450&r1=1060449&r2=1060450&view=diff
==============================================================================
---
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/GaussianAccumulator.java
(original)
+++
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/GaussianAccumulator.java
Tue Jan 18 16:28:08 2011
@@ -50,8 +50,9 @@ public interface GaussianAccumulator {
* Observe the vector
*
* @param x a Vector
+ * @param weight the double observation weight (usually 1.0)
*/
- void observe(Vector x);
+ void observe(Vector x, double weight);
/**
* Compute the mean, variance and standard deviation
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/OnlineGaussianAccumulator.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/OnlineGaussianAccumulator.java?rev=1060450&r1=1060449&r2=1060450&view=diff
==============================================================================
---
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/OnlineGaussianAccumulator.java
(original)
+++
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/OnlineGaussianAccumulator.java
Tue Jan 18 16:28:08 2011
@@ -24,17 +24,15 @@ import org.apache.mahout.math.function.S
* numerically-stable. See http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
*/
public class OnlineGaussianAccumulator implements GaussianAccumulator {
- private double n = 0;
+ private double sumWeight = 0.0;
private Vector mean;
-
- private Vector m2;
-
+ private Vector s;
private Vector variance;
@Override
public double getN() {
- return n;
+ return sumWeight;
}
@Override
@@ -47,23 +45,44 @@ public class OnlineGaussianAccumulator i
return variance.clone().assign(new SquareRootFunction());
}
- @Override
- public void observe(Vector x) {
- n++;
- Vector delta;
- if (mean != null) {
- delta = x.minus(mean);
- } else {
+ /* from Wikipedia: http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
+ *
+ * Weighted incremental algorithm
+ *
+ * def weighted_incremental_variance(dataWeightPairs):
+ * mean = 0
+ * S = 0
+ * sumweight = 0
+ * for x, weight in dataWeightPairs: # Alternately "for x in zip(data, weight):"
+ * temp = weight + sumweight
+ * Q = x - mean
+ * R = Q * weight / temp
+ * S = S + sumweight * Q * R
+ * mean = mean + R
+ * sumweight = temp
+ * Variance = S / (sumweight-1) # if sample is the population, omit -1
+ * return Variance
+ */
+
+ @Override
+ public void observe(Vector x, double weight) {
+ double temp = weight + sumWeight;
+ Vector Q;
+ if (mean == null) {
mean = x.like();
- delta = x.clone();
+ Q = x.clone();
+ } else {
+ Q = x.minus(mean);
}
- mean = mean.plus(delta.divide(n));
- if (m2 != null) {
- m2 = m2.plus(delta.times(x.minus(mean)));
+ Vector R = Q.times(weight).divide(temp);
+ if (s == null) {
+ s = Q.times(sumWeight).times(R);
} else {
- m2 = delta.times(x.minus(mean));
+ s = s.plus(Q.times(sumWeight).times(R));
}
- variance = m2.divide(n - 1);
+ mean = mean.plus(R);
+ sumWeight = temp;
+ variance = s.divide(sumWeight - 1);// # if sample is the population, omit -1
}
@Override
@@ -73,8 +92,8 @@ public class OnlineGaussianAccumulator i
@Override
public double getAverageStd() {
- if (n == 0) {
- return 0;
+ if (sumWeight == 0.0) {
+ return 0.0;
} else {
Vector std = getStd();
return std.zSum() / std.size();
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/RunningSumsGaussianAccumulator.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/RunningSumsGaussianAccumulator.java?rev=1060450&r1=1060449&r2=1060450&view=diff
==============================================================================
---
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/RunningSumsGaussianAccumulator.java
(original)
+++
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/RunningSumsGaussianAccumulator.java
Tue Jan 18 16:28:08 2011
@@ -25,14 +25,11 @@ import org.apache.mahout.math.function.S
* Suffers from overflow, underflow and roundoff error but has minimal observe-time overhead
*/
public class RunningSumsGaussianAccumulator implements GaussianAccumulator {
- private double s0 = 0;
+ private double s0 = 0.0;
private Vector s1;
-
private Vector s2;
-
private Vector mean;
-
private Vector std;
@Override
@@ -52,8 +49,8 @@ public class RunningSumsGaussianAccumula
@Override
public double getAverageStd() {
- if (s0 == 0) {
- return 0;
+ if (s0 == 0.0) {
+ return 0.0;
} else {
return std.zSum() / std.size();
}
@@ -65,14 +62,15 @@ public class RunningSumsGaussianAccumula
}
@Override
- public void observe(Vector x) {
- s0++;
+ public void observe(Vector x, double weight) {
+ s0 += weight;
+ Vector weightedX = x.times(weight);
if (s1 == null) {
- s1 = x.clone();
+ s1 = weightedX;
} else {
- x.addTo(s1);
+ weightedX.addTo(s1);
}
- Vector x2 = x.times(x);
+ Vector x2 = x.times(x).times(weight);
if (s2 == null) {
s2 = x2;
} else {
@@ -82,11 +80,10 @@ public class RunningSumsGaussianAccumula
@Override
public void compute() {
- if (s0 == 0) {
- return;
+ if (s0 != 0.0) {
+ mean = s1.divide(s0);
+ std = s2.times(s0).minus(s1.times(s1)).assign(new
SquareRootFunction()).divide(s0);
}
- mean = s1.divide(s0);
- std = s2.times(s0).minus(s1.times(s1)).assign(new
SquareRootFunction()).divide(s0);
}
}
Modified:
mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestGaussianAccumulators.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestGaussianAccumulators.java?rev=1060450&r1=1060449&r2=1060450&view=diff
==============================================================================
---
mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestGaussianAccumulators.java
(original)
+++
mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestGaussianAccumulators.java
Tue Jan 18 16:28:08 2011
@@ -22,7 +22,9 @@ import java.util.Collection;
import org.apache.mahout.clustering.dirichlet.UncommonDistributions;
import org.apache.mahout.common.MahoutTestCase;
import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.function.SquareRootFunction;
import org.junit.Before;
import org.junit.Test;
import org.slf4j.Logger;
@@ -30,37 +32,38 @@ import org.slf4j.LoggerFactory;
public final class TestGaussianAccumulators extends MahoutTestCase {
- private Collection<VectorWritable> sampleData = new
ArrayList<VectorWritable>();
-
private static final Logger log =
LoggerFactory.getLogger(TestGaussianAccumulators.class);
+ private Collection<VectorWritable> sampleData = new
ArrayList<VectorWritable>();
+ private int sampleN;
+ private Vector sampleMean;
+ private Vector sampleStd;
+ private Vector sampleVar;
+
@Override
@Before
public void setUp() throws Exception {
super.setUp();
sampleData = new ArrayList<VectorWritable>();
generateSamples();
- }
+ sampleN = 0;
+ Vector sum = new DenseVector(2);
+ for (VectorWritable v : sampleData) {
+ v.get().addTo(sum);
+ sampleN++;
+ }
+ sampleMean = sum.divide(sampleN);
- /**
- * Generate random samples and add them to the sampleData
- *
- * @param num
- * int number of samples to generate
- * @param mx
- * double x-value of the sample mean
- * @param my
- * double y-value of the sample mean
- * @param sd
- * double standard deviation of the samples
- * @throws Exception
- */
- private void generateSamples(int num, double mx, double my, double sd) {
- log.info("Generating {} samples m=[{}, {}] sd={}", new Object[] { num, mx,
my, sd });
- for (int i = 0; i < num; i++) {
- sampleData.add(new VectorWritable(new DenseVector(new double[] {
UncommonDistributions.rNorm(mx, sd),
- UncommonDistributions.rNorm(my, sd) })));
+ sampleVar = new DenseVector(2);
+ for (VectorWritable v : sampleData) {
+ Vector delta = v.get().minus(sampleMean);
+ delta.times(delta).addTo(sampleVar);
}
+ sampleVar = sampleVar.divide(sampleN - 1);
+ sampleStd = sampleVar.clone();
+ sampleStd.assign(new SquareRootFunction());
+ log.info("Observing {} samples m=[{}, {}] sd=[{}, {}]",
+ new Object[] { sampleN, sampleMean.get(0), sampleMean.get(1),
sampleStd.get(0), sampleStd.get(1) });
}
/**
@@ -86,7 +89,7 @@ public final class TestGaussianAccumulat
}
private void generateSamples() {
- generate2dSamples(500, 1, 2, 3, 4);
+ generate2dSamples(50000, 1, 2, 3, 4);
}
@Test
@@ -101,18 +104,76 @@ public final class TestGaussianAccumulat
}
@Test
- public void testAccumulatorResults() {
+ public void testAccumulatorOneSample() {
+ GaussianAccumulator accumulator0 = new RunningSumsGaussianAccumulator();
+ GaussianAccumulator accumulator1 = new OnlineGaussianAccumulator();
+ Vector sample = new DenseVector(2);
+ accumulator0.observe(sample, 1.0);
+ accumulator1.observe(sample, 1.0);
+ accumulator0.compute();
+ accumulator1.compute();
+ assertEquals("N", accumulator0.getN(), accumulator1.getN(), EPSILON);
+ assertEquals("Means", accumulator0.getMean(), accumulator1.getMean());
+ assertEquals("Avg Stds", accumulator0.getAverageStd(),
accumulator1.getAverageStd(), EPSILON);
+ }
+
+ @Test
+ public void testOLAccumulatorResults() {
+ GaussianAccumulator accumulator = new OnlineGaussianAccumulator();
+ for (VectorWritable vw : sampleData) {
+ accumulator.observe(vw.get(), 1.0);
+ }
+ accumulator.compute();
+ log.info("OL Observed {} samples m=[{}, {}] sd=[{}, {}]", new Object[] {
accumulator.getN(), accumulator.getMean().get(0),
+ accumulator.getMean().get(1), accumulator.getStd().get(0),
accumulator.getStd().get(1) });
+ assertEquals("OL N", sampleN, accumulator.getN(), EPSILON);
+ assertEquals("OL Mean", sampleMean.zSum(), accumulator.getMean().zSum(),
EPSILON);
+ assertEquals("OL Std", sampleStd.zSum(), accumulator.getStd().zSum(),
EPSILON);
+ }
+
+ @Test
+ public void testRSAccumulatorResults() {
+ GaussianAccumulator accumulator = new RunningSumsGaussianAccumulator();
+ for (VectorWritable vw : sampleData) {
+ accumulator.observe(vw.get(), 1.0);
+ }
+ accumulator.compute();
+ log.info("RS Observed {} samples m=[{}, {}] sd=[{}, {}]", new Object[] {
(int) accumulator.getN(),
+ accumulator.getMean().get(0), accumulator.getMean().get(1),
accumulator.getStd().get(0), accumulator.getStd().get(1) });
+ assertEquals("OL N", sampleN, accumulator.getN(), EPSILON);
+ assertEquals("OL Mean", sampleMean.zSum(), accumulator.getMean().zSum(),
EPSILON);
+ assertEquals("OL Std", sampleStd.zSum(), accumulator.getStd().zSum(),
0.0001);
+ }
+
+ @Test
+ public void testAccumulatorWeightedResults() {
+ GaussianAccumulator accumulator0 = new RunningSumsGaussianAccumulator();
+ GaussianAccumulator accumulator1 = new OnlineGaussianAccumulator();
+ for (VectorWritable vw : sampleData) {
+ accumulator0.observe(vw.get(), 0.5);
+ accumulator1.observe(vw.get(), 0.5);
+ }
+ accumulator0.compute();
+ accumulator1.compute();
+ assertEquals("N", accumulator0.getN(), accumulator1.getN(), EPSILON);
+ assertEquals("Means", accumulator0.getMean().zSum(),
accumulator1.getMean().zSum(), EPSILON);
+ assertEquals("Stds", accumulator0.getStd().zSum(),
accumulator1.getStd().zSum(), 0.001);
+ assertEquals("Variance", accumulator0.getVariance().zSum(),
accumulator1.getVariance().zSum(), 0.01);
+ }
+
+ @Test
+ public void testAccumulatorWeightedResults2() {
GaussianAccumulator accumulator0 = new RunningSumsGaussianAccumulator();
GaussianAccumulator accumulator1 = new OnlineGaussianAccumulator();
for (VectorWritable vw : sampleData) {
- accumulator0.observe(vw.get());
- accumulator1.observe(vw.get());
+ accumulator0.observe(vw.get(), 1.5);
+ accumulator1.observe(vw.get(), 1.5);
}
accumulator0.compute();
accumulator1.compute();
assertEquals("N", accumulator0.getN(), accumulator1.getN(), EPSILON);
assertEquals("Means", accumulator0.getMean().zSum(),
accumulator1.getMean().zSum(), EPSILON);
- assertEquals("Stds", accumulator0.getStd().zSum(),
accumulator1.getStd().zSum(), 0.01);
- assertEquals("Variance", accumulator0.getVariance().zSum(),
accumulator1.getVariance().zSum(), 0.1);
+ assertEquals("Stds", accumulator0.getStd().zSum(),
accumulator1.getStd().zSum(), 0.001);
+ assertEquals("Variance", accumulator0.getVariance().zSum(),
accumulator1.getVariance().zSum(), 0.01);
}
}
Modified:
mahout/trunk/utils/src/main/java/org/apache/mahout/clustering/cdbw/CDbwEvaluator.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/utils/src/main/java/org/apache/mahout/clustering/cdbw/CDbwEvaluator.java?rev=1060450&r1=1060449&r2=1060450&view=diff
==============================================================================
---
mahout/trunk/utils/src/main/java/org/apache/mahout/clustering/cdbw/CDbwEvaluator.java
(original)
+++
mahout/trunk/utils/src/main/java/org/apache/mahout/clustering/cdbw/CDbwEvaluator.java
Tue Jan 18 16:28:08 2011
@@ -50,13 +50,9 @@ public class CDbwEvaluator {
private static final Logger log =
LoggerFactory.getLogger(CDbwEvaluator.class);
private final Map<Integer, List<VectorWritable>> representativePoints;
-
private final Map<Integer, Double> stDevs = new HashMap<Integer, Double>();
-
private final List<Cluster> clusters;
-
private final DistanceMeasure measure;
-
private boolean pruned;
/**
@@ -136,7 +132,7 @@ public class CDbwEvaluator {
List<VectorWritable> repPts = representativePoints.get(cI);
GaussianAccumulator accumulator = new OnlineGaussianAccumulator();
for (VectorWritable vw : repPts) {
- accumulator.observe(vw.get());
+ accumulator.observe(vw.get(), 1.0);
}
accumulator.compute();
double d = accumulator.getAverageStd();