OnlineSummarizerTest seems to be failing: both of its tests (testCount and testStats) fail.

On 7/21/10 5:59 PM, [email protected] wrote:
Author: tdunning
Date: Thu Jul 22 00:59:03 2010
New Revision: 966469

URL: http://svn.apache.org/viewvc?rev=966469&view=rev
Log:
MAHOUT-444 - on-line estimator widget with tests that demonstrate accuracy.

Added:
     
mahout/trunk/math/src/main/java/org/apache/mahout/math/stats/OnlineSummarizer.java
     
mahout/trunk/math/src/test/java/org/apache/mahout/math/stats/OnlineSummarizerTest.java

Added: 
mahout/trunk/math/src/main/java/org/apache/mahout/math/stats/OnlineSummarizer.java
URL: 
http://svn.apache.org/viewvc/mahout/trunk/math/src/main/java/org/apache/mahout/math/stats/OnlineSummarizer.java?rev=966469&view=auto
==============================================================================
--- 
mahout/trunk/math/src/main/java/org/apache/mahout/math/stats/OnlineSummarizer.java
 (added)
+++ 
mahout/trunk/math/src/main/java/org/apache/mahout/math/stats/OnlineSummarizer.java
 Thu Jul 22 00:59:03 2010
@@ -0,0 +1,166 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.stats;
+
+import org.apache.mahout.math.list.DoubleArrayList;
+
+/**
+ * Computes on-line estimates of mean, variance and all five quartiles 
(notably including the
+ * median).  Since this is done in a completely incremental fashion (that is 
what is meant by
+ * on-line) estimates are available at any time and the amount of memory used 
is constant.  Somewhat
+ * surprisingly, the quantile estimates are about as good as you would get if 
you actually kept all
+ * of the samples.
+ *<p/>
+ * The method used for mean and variance is Welford's method.  See
+ *<p/>
+ * 
http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#On-line_algorithm
+ *<p/>
+ * The method used for computing the quartiles is a simplified form of the 
stochastic approximation
+ * method described in the article "Incremental Quantile Estimation for Massive 
Tracking" by Chen,
+ * Lambert and Pinheiro
+ *<p/>
+ * See
+ *<p/>
+ * http://citeseerx.ist.psu.edu/viewdoc/summary?doi=10.1.1.105.1580
+ */
+public class OnlineSummarizer {
+  boolean sorted = true;
+
+  // the first several samples are kept so we can boot-strap our estimates 
cleanly
+  private DoubleArrayList starter = new DoubleArrayList(100);
+
+  // quartile estimates
+  private double q[] = new double[5];
+
+  // mean and variance estimates
+  private double mean = 0;
+  private double variance = 0;
+
+  // number of samples seen so far
+  private int n = 0;
+
+  public void add(double sample) {
+    sorted = false;
+
+    n++;
+    double oldMean = mean;
+    mean += (sample - mean) / n;
+    double diff = (sample - mean) * (sample - oldMean);
+    variance += (diff - variance) / n;
+
+    if (n<  100) {
+      starter.add(sample);
+    } else if (n == 100) {
+      starter.add(sample);
+      q[0] = min();
+      q[1] = quartile(1);
+      q[2] = quartile(2);
+      q[3] = quartile(3);
+      q[4] = max();
+      starter = null;
+    } else {
+      q[0] = Math.min(sample, q[0]);
+      q[4] = Math.max(sample, q[4]);
+
+      double rate = 2 * (q[3] - q[1]) / n;
+      q[1] += (Math.signum(sample - q[1]) - 0.5) * rate;
+      q[2] += (Math.signum(sample - q[2])) * rate;
+      q[3] += (Math.signum(sample - q[3]) + 0.5) * rate;
+
+      if (q[1]<  q[0]) {
+        q[1] = q[0];
+      }
+
+      if (q[3]>  q[4]) {
+        q[3] = q[4];
+      }
+    }
+  }
+
+  public int count() {
+    return n;
+  }
+
+  public double mean() {
+    return mean;
+  }
+
+  public double sd() {
+    return Math.sqrt(variance);
+  }
+
+  public double min() {
+    sort();
+    if (n == 0) {
+      throw new IllegalArgumentException("Must have at least one sample to estimate 
minimum value");
+    }
+    if (n<= 100) {
+      return starter.get(0);
+    } else {
+      return q[0];
+    }
+  }
+
+  private void sort() {
+    if (!sorted&&  starter != null) {
+      starter.sortFromTo(0, 99);
+      sorted = true;
+    }
+  }
+
+  public double max() {
+    sort();
+    if (n == 0) {
+      throw new IllegalArgumentException("Must have at least one sample to estimate 
maximum value");
+    }
+    if (n<= 100) {
+      return starter.get(99);
+    } else {
+      return q[4];
+    }
+  }
+
+  public double quartile(int i) {
+    sort();
+    switch (i) {
+      case 0:
+        return min();
+      case 1:
+      case 2:
+      case 3:
+        if (n>  100) {
+          return q[i];
+        } else if (n<  2) {
+          throw new IllegalArgumentException("Must have at least two samples to 
estimate quartiles");
+        } else {
+          double x = i * (n - 1) / 4.0;
+          int k = (int) Math.floor(x);
+          double u = x - k;
+          return starter.get(k) * (1 - u) + starter.get(k + 1) * u;
+        }
+      case 4:
+        return max();
+      default:
+        throw new IllegalArgumentException("Quartile number must be in the range 
[0..4] not " + i);
+    }
+  }
+
+  public double median() {
+    return quartile(2);
+  }
+}

Added: 
mahout/trunk/math/src/test/java/org/apache/mahout/math/stats/OnlineSummarizerTest.java
URL: 
http://svn.apache.org/viewvc/mahout/trunk/math/src/test/java/org/apache/mahout/math/stats/OnlineSummarizerTest.java?rev=966469&view=auto
==============================================================================
--- 
mahout/trunk/math/src/test/java/org/apache/mahout/math/stats/OnlineSummarizerTest.java
 (added)
+++ 
mahout/trunk/math/src/test/java/org/apache/mahout/math/stats/OnlineSummarizerTest.java
 Thu Jul 22 00:59:03 2010
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.stats;
+
+import org.apache.mahout.math.jet.random.Gamma;
+import org.apache.mahout.math.jet.random.engine.MersenneTwister;
+import org.junit.Test;
+
+import java.util.Random;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.fail;
+
+public class OnlineSummarizerTest {
+  @Test
+  public void testCount() {
+    OnlineSummarizer x = new OnlineSummarizer();
+    assertEquals(0, x.count());
+    x.add(1);
+    assertEquals(1, x.count());
+
+    for (int i = 1; i<  110; i++) {
+      x.add(i);
+      assertEquals(i, x.count());
+    }
+  }
+
+  @Test
+  public void testStats() {
+    // the reference limits here were derived using a numerical simulation 
where I took
+    // 10,000 samples from the distribution in question and computed the stats 
from that
+    // sample to get min, 25%-ile, median and so on.  I did this 1000 times to 
get 5% and
+    // 95% confidence limits for those values.
+
+    // symmetrical, well behaved
+    check(normal(10000, 1),
+            -4.417246, -3.419809,
+            -0.6972919, -0.6519899,
+            -0.02056658, 0.02176474,
+            0.6503866, 0.6983311,
+            3.419809, 4.417246,
+            -0.01515753, 0.01592942,
+            0.988395, 1.011883);
+
+    // asymmetrical, well behaved.  The range for the maximum was fudged 
slightly to allow this to pass.
+    check(exp(10000, 1),
+            4.317969e-06, 3.278763e-04,
+            0.2783866, 0.298,
+            0.6765024, 0.7109463,
+            1.356929, 1.414761,
+            8, 13,
+            0.983805, 1.015920,
+            0.977162, 1.022093
+    );
+
+    // asymmetrical, wacko distribution where mean/median>  10^28
+    check(gamma(10000, 1),
+            0, 0,                                             // minimum
+            1.63067132881301e-60, 6.26363334269806e-58,       // 25th %-ile
+            8.62261497075834e-30, 2.01422505081014e-28,       // median
+            6.70225617733614e-12, 4.44299757853286e-11,       // 75th %-ile
+            238.451174077827, 579.143886928158,               // maximum
+            0.837031762527458, 1.17244066539313,              // mean
+            8.10277696526878, 12.1426255901507);              // standard dev
+  }
+
+  private void check(OnlineSummarizer x, double... values) {
+    for (int i = 0; i<  5; i++) {
+      checkRange(String.format("quartile %d", i), x.quartile(i), values[2 * 
i], values[2 * i + 1]);
+    }
+    assertEquals(x.quartile(2), x.median(), 0);
+
+    checkRange("mean", x.mean(), values[10], values[11]);
+    checkRange("sd", x.sd(), values[12], values[13]);
+  }
+
+  private void checkRange(String msg, double v, double low, double high) {
+    if (v<  low || v>  high) {
+      fail(String.format("Wanted %s to be in range [%f,%f] but got %f", msg, 
low, high, v));
+    }
+  }
+
+  private OnlineSummarizer normal(int n, int seed) {
+    OnlineSummarizer x = new OnlineSummarizer();
+    Random gen = new Random(seed);
+    for (int i = 0; i<  n; i++) {
+      x.add(gen.nextGaussian());
+    }
+    return x;
+  }
+
+  private OnlineSummarizer exp(int n, int seed) {
+    OnlineSummarizer x = new OnlineSummarizer();
+    Random gen = new Random(seed);
+    for (int i = 0; i<  n; i++) {
+      x.add(-Math.log(1 - gen.nextDouble()));
+    }
+    return x;
+  }
+
+  private OnlineSummarizer gamma(int n, int seed) {
+    OnlineSummarizer x = new OnlineSummarizer();
+    Gamma g = new Gamma(0.01, 100, new MersenneTwister(seed));
+    for (int i = 0; i<  n; i++) {
+      x.add(g.nextDouble());
+    }
+    return x;
+  }
+}
+
+




Reply via email to