Author: srowen
Date: Tue Mar 22 16:15:36 2011
New Revision: 1084234

URL: http://svn.apache.org/viewvc?rev=1084234&view=rev
Log:
MAHOUT-630 weighted average fix and add stddev

Added:
    
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/WeightedRunningAverageAndStdDev.java
      - copied, changed from r1083546, 
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/WeightedRunningAverage.java
    
mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/common/WeightedRunningAverageTest.java
      - copied, changed from r1083546, 
mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/common/RunningAverageTest.java
Modified:
    
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/WeightedRunningAverage.java

Modified: 
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/WeightedRunningAverage.java
URL: 
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/WeightedRunningAverage.java?rev=1084234&r1=1084233&r2=1084234&view=diff
==============================================================================
--- 
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/WeightedRunningAverage.java
 (original)
+++ 
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/WeightedRunningAverage.java
 Tue Mar 22 16:15:36 2011
@@ -21,7 +21,7 @@ import java.io.Serializable;
 
 import com.google.common.base.Preconditions;
 
-public final class WeightedRunningAverage implements RunningAverage, 
Serializable {
+public class WeightedRunningAverage implements RunningAverage, Serializable {
   
   private double totalWeight;
   private double average;
@@ -42,7 +42,7 @@ public final class WeightedRunningAverag
     if (oldTotalWeight <= 0.0) {
       average = datum * weight;
     } else {
-      average = average * oldTotalWeight / totalWeight + datum / totalWeight;
+      average = average * oldTotalWeight / totalWeight + datum * weight / 
totalWeight;
     }
   }
   
@@ -58,7 +58,7 @@ public final class WeightedRunningAverag
       average = Double.NaN;
       totalWeight = 0.0;
     } else {
-      average = average * oldTotalWeight / totalWeight - datum / totalWeight;
+      average = average * oldTotalWeight / totalWeight - datum * weight / 
totalWeight;
     }
   }
   

Copied: 
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/WeightedRunningAverageAndStdDev.java
 (from r1083546, 
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/WeightedRunningAverage.java)
URL: 
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/WeightedRunningAverageAndStdDev.java?p2=mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/WeightedRunningAverageAndStdDev.java&p1=mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/WeightedRunningAverage.java&r1=1083546&r2=1084234&rev=1084234&view=diff
==============================================================================
--- 
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/WeightedRunningAverage.java
 (original)
+++ 
mahout/trunk/core/src/main/java/org/apache/mahout/cf/taste/impl/common/WeightedRunningAverageAndStdDev.java
 Tue Mar 22 16:15:36 2011
@@ -17,79 +17,68 @@
 
 package org.apache.mahout.cf.taste.impl.common;
 
-import java.io.Serializable;
+/**
+ * This subclass also provides for a weighted estimate of the sample standard deviation.
+ * See <a href="http://en.wikipedia.org/wiki/Mean_square_weighted_deviation">estimate formulae here</a>.
+ */
+public final class WeightedRunningAverageAndStdDev extends 
WeightedRunningAverage implements RunningAverageAndStdDev {
 
-import com.google.common.base.Preconditions;
+  private double totalSquaredWeight;
+  private double totalWeightedData;
+  private double totalWeightedSquaredData;
 
-public final class WeightedRunningAverage implements RunningAverage, 
Serializable {
-  
-  private double totalWeight;
-  private double average;
-  
-  public WeightedRunningAverage() {
-    totalWeight = 0.0;
-    average = Double.NaN;
+  public WeightedRunningAverageAndStdDev() {
+    totalSquaredWeight = 0.0;
+    totalWeightedData = 0.0;
+    totalWeightedSquaredData = 0.0;
   }
   
   @Override
-  public synchronized void addDatum(double datum) {
-    addDatum(datum, 1.0);
-  }
-  
   public synchronized void addDatum(double datum, double weight) {
-    double oldTotalWeight = totalWeight;
-    totalWeight += weight;
-    if (oldTotalWeight <= 0.0) {
-      average = datum * weight;
-    } else {
-      average = average * oldTotalWeight / totalWeight + datum / totalWeight;
-    }
+    super.addDatum(datum, weight);
+    totalSquaredWeight += weight * weight;
+    double weightedData = datum * weight;
+    totalWeightedData += weightedData;
+    totalWeightedSquaredData += weightedData * datum;
   }
   
   @Override
-  public synchronized void removeDatum(double datum) {
-    removeDatum(datum, 1.0);
-  }
-  
   public synchronized void removeDatum(double datum, double weight) {
-    double oldTotalWeight = totalWeight;
-    totalWeight -= weight;
-    if (totalWeight <= 0.0) {
-      average = Double.NaN;
-      totalWeight = 0.0;
-    } else {
-      average = average * oldTotalWeight / totalWeight - datum / totalWeight;
+    super.removeDatum(datum, weight);
+    totalSquaredWeight -= weight * weight;
+    if (totalSquaredWeight <= 0.0) {
+      totalSquaredWeight = 0.0;
+    }
+    double weightedData = datum * weight;
+    totalWeightedData -= weightedData;
+    if (totalWeightedData <= 0.0) {
+      totalWeightedData = 0.0;
+    }
+    totalWeightedSquaredData -= weightedData * datum;
+    if (totalWeightedSquaredData <= 0.0) {
+      totalWeightedSquaredData = 0.0;
     }
   }
-  
+
+  /**
+   * @throws UnsupportedOperationException
+   */
   @Override
-  public synchronized void changeDatum(double delta) {
-    changeDatum(delta, 1.0);
-  }
-  
   public synchronized void changeDatum(double delta, double weight) {
-    Preconditions.checkArgument(weight <= totalWeight);
-    average += delta * weight / totalWeight;
-  }
-  
-  public synchronized double getTotalWeight() {
-    return totalWeight;
-  }
-  
-  /** @return {@link #getTotalWeight()} */
-  @Override
-  public synchronized int getCount() {
-    return (int) totalWeight;
+    throw new UnsupportedOperationException();
   }
   
+
   @Override
-  public synchronized double getAverage() {
-    return average;
+  public synchronized double getStandardDeviation() {
+    double totalWeight = getTotalWeight();
+    return Math.sqrt((totalWeightedSquaredData * totalWeight - 
totalWeightedData * totalWeightedData) /
+        (totalWeight * totalWeight - totalSquaredWeight));
   }
   
   @Override
   public synchronized String toString() {
-    return String.valueOf(average);
+    return String.valueOf(String.valueOf(getAverage()) + ',' + 
getStandardDeviation());
   }
-  
+
 }

Copied: 
mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/common/WeightedRunningAverageTest.java
 (from r1083546, 
mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/common/RunningAverageTest.java)
URL: 
http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/common/WeightedRunningAverageTest.java?p2=mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/common/WeightedRunningAverageTest.java&p1=mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/common/RunningAverageTest.java&r1=1083546&r2=1084234&rev=1084234&view=diff
==============================================================================
--- 
mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/common/RunningAverageTest.java
 (original)
+++ 
mahout/trunk/core/src/test/java/org/apache/mahout/cf/taste/impl/common/WeightedRunningAverageTest.java
 Tue Mar 22 16:15:36 2011
@@ -20,49 +20,66 @@ package org.apache.mahout.cf.taste.impl.
 import org.apache.mahout.cf.taste.impl.TasteTestCase;
 import org.junit.Test;
 
-/** <p>Tests {@link FullRunningAverage}.</p> */
-public final class RunningAverageTest extends TasteTestCase {
+/**
+ * <p>Tests {@link WeightedRunningAverage} and {@link WeightedRunningAverageAndStdDev}.</p>
+ */
+public final class WeightedRunningAverageTest extends TasteTestCase {
 
   @Test
-  public void testFull() {
-    doTestRunningAverage(new FullRunningAverage());
-  }
-
-  @Test
-  public void testCompact() {
-    doTestRunningAverage(new CompactRunningAverage());
-  }
+  public void testWeighted() {
 
-  private static void doTestRunningAverage(RunningAverage runningAverage) {
+    WeightedRunningAverage runningAverage = new WeightedRunningAverage();
 
     assertEquals(0, runningAverage.getCount());
     assertTrue(Double.isNaN(runningAverage.getAverage()));
     runningAverage.addDatum(1.0);
-    assertEquals(1, runningAverage.getCount());
     assertEquals(1.0, runningAverage.getAverage(), EPSILON);
-    runningAverage.addDatum(1.0);
-    assertEquals(2, runningAverage.getCount());
+    runningAverage.addDatum(1.0, 2.0);
     assertEquals(1.0, runningAverage.getAverage(), EPSILON);
-    runningAverage.addDatum(4.0);
-    assertEquals(3, runningAverage.getCount());
+    runningAverage.addDatum(8.0, 0.5);
     assertEquals(2.0, runningAverage.getAverage(), EPSILON);
     runningAverage.addDatum(-4.0);
-    assertEquals(4, runningAverage.getCount());
-    assertEquals(0.5, runningAverage.getAverage(), EPSILON);
+    assertEquals(2.0/3.0, runningAverage.getAverage(), EPSILON);
 
     runningAverage.removeDatum(-4.0);
-    assertEquals(3, runningAverage.getCount());
     assertEquals(2.0, runningAverage.getAverage(), EPSILON);
-    runningAverage.removeDatum(4.0);
-    assertEquals(2, runningAverage.getCount());
-    assertEquals(1.0, runningAverage.getAverage(), EPSILON);
+    runningAverage.removeDatum(2.0, 2.0);
+    assertEquals(2.0, runningAverage.getAverage(), EPSILON);
 
     runningAverage.changeDatum(0.0);
-    assertEquals(2, runningAverage.getCount());
+    assertEquals(2.0, runningAverage.getAverage(), EPSILON);
+    runningAverage.changeDatum(4.0, 0.5);
+    assertEquals(5.0/1.5, runningAverage.getAverage(), EPSILON);
+  }
+
+  @Test
+  public void testWeightedAndStdDev() {
+
+    WeightedRunningAverageAndStdDev runningAverage = new 
WeightedRunningAverageAndStdDev();
+
+    assertEquals(0, runningAverage.getCount());
+    assertTrue(Double.isNaN(runningAverage.getAverage()));
+    assertTrue(Double.isNaN(runningAverage.getStandardDeviation()));
+
+    runningAverage.addDatum(1.0);
+    assertEquals(1.0, runningAverage.getAverage(), EPSILON);
+    assertTrue(Double.isNaN(runningAverage.getStandardDeviation()));
+    runningAverage.addDatum(1.0, 2.0);
     assertEquals(1.0, runningAverage.getAverage(), EPSILON);
-    runningAverage.changeDatum(2.0);
-    assertEquals(2, runningAverage.getCount());
+    assertEquals(0.0, runningAverage.getStandardDeviation(), EPSILON);
+    runningAverage.addDatum(8.0, 0.5);
+    assertEquals(2.0, runningAverage.getAverage(), EPSILON);
+    assertEquals(Math.sqrt(10.5), runningAverage.getStandardDeviation(), 
EPSILON);
+    runningAverage.addDatum(-4.0);
+    assertEquals(2.0/3.0, runningAverage.getAverage(), EPSILON);
+    assertEquals(Math.sqrt(15.75), runningAverage.getStandardDeviation(), 
EPSILON);
+
+    runningAverage.removeDatum(-4.0);
+    assertEquals(2.0, runningAverage.getAverage(), EPSILON);
+    assertEquals(Math.sqrt(10.5), runningAverage.getStandardDeviation(), 
EPSILON);
+    runningAverage.removeDatum(2.0, 2.0);
     assertEquals(2.0, runningAverage.getAverage(), EPSILON);
+    assertEquals(Math.sqrt(31.5), runningAverage.getStandardDeviation(), 
EPSILON);
   }
 
 }


Reply via email to