Author: jeastman
Date: Sat Aug 21 02:21:57 2010
New Revision: 987686

URL: http://svn.apache.org/viewvc?rev=987686&view=rev
Log:
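MAHOUT-479: Fixed a bug in TestVectorModelClassifier and experimented with the
pdf() implementations in GaussianCluster and AsymmetricSampledNormalModel
(ASNModel). Synthetic control now seems to work better with Dirichlet
clustering, but I am concerned about how strongly the outcome depends on which
of the two pdf() implementations is used (GaussianCluster now averages the
per-dimension normal densities, while AsymmetricSampledNormalModel multiplies
them).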

Removed:
    mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/dirichlet/README.txt
Modified:
    mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletClusterer.java
    mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/AsymmetricSampledNormalModel.java
    mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/GaussianCluster.java
    mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestVectorModelClassifier.java
    mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayDirichlet.java
    mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/dirichlet/Job.java

Modified: 
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletClusterer.java
URL: 
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletClusterer.java?rev=987686&r1=987685&r2=987686&view=diff
==============================================================================
--- 
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletClusterer.java
 (original)
+++ 
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletClusterer.java
 Sat Aug 21 02:21:57 2010
@@ -404,12 +404,16 @@ public class DirichletClusterer {
    */
   private void emitMostLikelyCluster(VectorWritable vector, 
List<DirichletCluster> clusters, Vector pi, Writer writer)
       throws IOException {
+    double maxPdf = 0;
+    int clusterId = -1;
     for (int i = 0; i < clusters.size(); i++) {
       double pdf = pi.get(i);
-      if (pdf > threshold && clusters.get(i).getTotalCount() > 0) {
-        //System.out.println(i + ": " + ClusterBase.formatVector(vector.get(), 
null));
-        writer.append(new IntWritable(i), new WeightedVectorWritable(pdf, 
vector));
+      if (pdf > maxPdf) {
+        maxPdf = pdf;
+        clusterId = i;
       }
     }
+    //System.out.println(i + ": " + ClusterBase.formatVector(vector.get(), 
null));
+    writer.append(new IntWritable(clusterId), new 
WeightedVectorWritable(maxPdf, vector));
   }
 }

Modified: 
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/AsymmetricSampledNormalModel.java
URL: 
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/AsymmetricSampledNormalModel.java?rev=987686&r1=987685&r2=987686&view=diff
==============================================================================
--- 
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/AsymmetricSampledNormalModel.java
 (original)
+++ 
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/AsymmetricSampledNormalModel.java
 Sat Aug 21 02:21:57 2010
@@ -26,6 +26,7 @@ import org.apache.mahout.clustering.Abst
 import org.apache.mahout.clustering.Cluster;
 import org.apache.mahout.clustering.JsonModelAdapter;
 import org.apache.mahout.clustering.Model;
+import org.apache.mahout.clustering.dirichlet.UncommonDistributions;
 import org.apache.mahout.math.Vector;
 import org.apache.mahout.math.VectorWritable;
 import org.apache.mahout.math.function.SquareRootFunction;
@@ -35,27 +36,27 @@ import com.google.gson.GsonBuilder;
 import com.google.gson.reflect.TypeToken;
 
 public class AsymmetricSampledNormalModel implements Cluster {
-  
-  private static final double SQRT2PI = Math.sqrt(2.0 * Math.PI);
-  private static final Type MODEL_TYPE = new TypeToken<Model<Vector>>() { 
}.getType();
+
+  private static final Type MODEL_TYPE = new TypeToken<Model<Vector>>() {
+  }.getType();
 
   private int id;
-  
+
   // the parameters
   private Vector mean;
-  
+
   private Vector stdDev;
-  
+
   // the observation statistics, initialized by the first observation
   private int s0;
-  
+
   private Vector s1;
-  
+
   private Vector s2;
 
   public AsymmetricSampledNormalModel() {
   }
-  
+
   public AsymmetricSampledNormalModel(int id, Vector mean, Vector stdDev) {
     this.id = id;
     this.mean = mean;
@@ -64,15 +65,15 @@ public class AsymmetricSampledNormalMode
     this.s1 = mean.like();
     this.s2 = mean.like();
   }
-  
+
   public Vector getMean() {
     return mean;
   }
-  
+
   public Vector getStdDev() {
     return stdDev;
   }
-  
+
   /**
    * Return an instance with the same parameters
    * 
@@ -81,7 +82,7 @@ public class AsymmetricSampledNormalMode
   public AsymmetricSampledNormalModel sampleFromPosterior() {
     return new AsymmetricSampledNormalModel(id, mean, stdDev);
   }
-  
+
   @Override
   public void observe(VectorWritable v) {
     Vector x = v.get();
@@ -97,7 +98,7 @@ public class AsymmetricSampledNormalMode
       s2 = s2.plus(x.times(x));
     }
   }
-  
+
   @Override
   public void computeParameters() {
     if (s0 == 0) {
@@ -111,44 +112,29 @@ public class AsymmetricSampledNormalMode
       stdDev.assign(Double.MIN_NORMAL);
     }
   }
-  
-  /**
-   * Calculate a pdf using the supplied sample and stdDev
-   * 
-   * @param x
-   *          a Vector sample
-   * @param sd
-   *          a double std deviation
-   */
-  private double pdf(Vector x, double sd) {
-    double sd2 = sd * sd;
-    double exp = -(x.dot(x) - 2 * x.dot(mean) + mean.dot(mean)) / (2 * sd2);
-    double ex = Math.exp(exp);
-    return ex / (sd * SQRT2PI);
-  }
-  
+
   @Override
   public double pdf(VectorWritable v) {
     Vector x = v.get();
     // return the product of the component pdfs
-    // TODO: is this reasonable? correct?
-    double pdf = pdf(x, stdDev.get(0));
-    for (int i = 1; i < x.size(); i++) {
-      pdf *= pdf(x, stdDev.get(i));
+    // TODO: is this reasonable? correct? It seems to work in some cases.
+    double pdf = 1;
+    for (int i = 0; i < x.size(); i++) {
+      pdf *= UncommonDistributions.dNorm(x.getQuick(i), 
getCenter().getQuick(i), getRadius().getQuick(i));
     }
     return pdf;
   }
-  
+
   @Override
   public int count() {
     return s0;
   }
-  
+
   @Override
   public String toString() {
     return asFormatString(null);
   }
-  
+
   @Override
   public String asFormatString(String[] bindings) {
     StringBuilder buf = new StringBuilder(50);
@@ -163,7 +149,7 @@ public class AsymmetricSampledNormalMode
     buf.append('}');
     return buf.toString();
   }
-  
+
   @Override
   public void readFields(DataInput in) throws IOException {
     this.id = in.readInt();
@@ -178,7 +164,7 @@ public class AsymmetricSampledNormalMode
     temp.readFields(in);
     this.s2 = temp.get();
   }
-  
+
   @Override
   public void write(DataOutput out) throws IOException {
     out.writeInt(id);
@@ -188,7 +174,7 @@ public class AsymmetricSampledNormalMode
     VectorWritable.writeVector(out, s1);
     VectorWritable.writeVector(out, s2);
   }
-  
+
   @Override
   public String asJsonString() {
     GsonBuilder builder = new GsonBuilder();

Modified: 
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/GaussianCluster.java
URL: 
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/GaussianCluster.java?rev=987686&r1=987685&r2=987686&view=diff
==============================================================================
--- 
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/GaussianCluster.java
 (original)
+++ 
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/GaussianCluster.java
 Sat Aug 21 02:21:57 2010
@@ -2,6 +2,7 @@ package org.apache.mahout.clustering.dir
 
 import org.apache.mahout.clustering.AbstractCluster;
 import org.apache.mahout.clustering.Model;
+import org.apache.mahout.clustering.dirichlet.UncommonDistributions;
 import org.apache.mahout.math.Vector;
 import org.apache.mahout.math.VectorWritable;
 
@@ -35,28 +36,17 @@ public class GaussianCluster extends Abs
   @Override
   public double pdf(VectorWritable vw) {
     Vector x = vw.get();
-    // return the product of the component pdfs
+    // return the average of the component pdfs
     // TODO: is this reasonable? correct?
-    double pdf = pdf(x, getRadius().get(0));
-    for (int i = 1; i < x.size(); i++) {
-      pdf *= pdf(x, getRadius().get(i));
+    double pdf = 0;
+    for (int i = 0; i < x.size(); i++) {
+      double x2 = x.get(i);
+      double m = getCenter().get(i);
+      double s = getRadius().get(i);
+      double dNorm = UncommonDistributions.dNorm(x2, m, s);
+      pdf += dNorm;
     }
-    return pdf;
-  }
-
-  /**
-   * Calculate a pdf using the supplied sample and stdDev
-   * 
-   * @param x
-   *          a Vector sample
-   * @param sd
-   *          a double std deviation
-   */
-  private double pdf(Vector x, double sd) {
-    double sd2 = sd * sd;
-    double exp = -(x.dot(x) - 2 * x.dot(getCenter()) + 
getCenter().dot(getCenter())) / (2 * sd2);
-    double ex = Math.exp(exp);
-    return ex / (sd * SQRT2PI);
+    return pdf / x.size();
   }
 
 }

Modified: 
mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestVectorModelClassifier.java
URL: 
http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestVectorModelClassifier.java?rev=987686&r1=987685&r2=987686&view=diff
==============================================================================
--- 
mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestVectorModelClassifier.java
 (original)
+++ 
mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestVectorModelClassifier.java
 Sat Aug 21 02:21:57 2010
@@ -6,6 +6,7 @@ import java.util.List;
 import org.apache.commons.lang.NotImplementedException;
 import org.apache.mahout.classifier.AbstractVectorClassifier;
 import org.apache.mahout.clustering.canopy.Canopy;
+import 
org.apache.mahout.clustering.dirichlet.models.AsymmetricSampledNormalModel;
 import org.apache.mahout.clustering.dirichlet.models.GaussianCluster;
 import org.apache.mahout.clustering.fuzzykmeans.SoftCluster;
 import org.apache.mahout.clustering.kmeans.Cluster;
@@ -93,9 +94,21 @@ public class TestVectorModelClassifier e
     models.add(new GaussianCluster(new DenseVector(2).assign(-1), new 
DenseVector(2).assign(1), 2));
     AbstractVectorClassifier classifier = new VectorModelClassifier(models);
     Vector pdf = classifier.classify(new DenseVector(2));
-    assertEquals("[0,0]", "[0.107, 0.787, 0.107]", 
AbstractCluster.formatVector(pdf, null));
+    assertEquals("[0,0]", "[0.274, 0.452, 0.274]", 
AbstractCluster.formatVector(pdf, null));
+    pdf = classifier.classify(new DenseVector(2).assign(2));
+    assertEquals("[2,2]", "[0.806, 0.180, 0.015]", 
AbstractCluster.formatVector(pdf, null));
+  }
+
+  public void testASNClusterClassification() {
+    List<Model<VectorWritable>> models = new 
ArrayList<Model<VectorWritable>>();
+    models.add(new AsymmetricSampledNormalModel(0, new 
DenseVector(2).assign(1), new DenseVector(2).assign(1)));
+    models.add(new AsymmetricSampledNormalModel(1, new DenseVector(2), new 
DenseVector(2).assign(1)));
+    models.add(new AsymmetricSampledNormalModel(2, new 
DenseVector(2).assign(-1), new DenseVector(2).assign(1)));
+    AbstractVectorClassifier classifier = new VectorModelClassifier(models);
+    Vector pdf = classifier.classify(new DenseVector(2));
+    assertEquals("[0,0]", "[0.212, 0.576, 0.212]", 
AbstractCluster.formatVector(pdf, null));
     pdf = classifier.classify(new DenseVector(2).assign(2));
-    assertEquals("[2,2]", "[0.998, 0.002, 0.000]", 
AbstractCluster.formatVector(pdf, null));
+    assertEquals("[2,2]", "[0.952, 0.047, 0.000]", 
AbstractCluster.formatVector(pdf, null));
   }
 
 }

Modified: 
mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayDirichlet.java
URL: 
http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayDirichlet.java?rev=987686&r1=987685&r2=987686&view=diff
==============================================================================
--- 
mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayDirichlet.java
 (original)
+++ 
mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayDirichlet.java
 Sat Aug 21 02:21:57 2010
@@ -25,7 +25,7 @@ import java.util.List;
 import org.apache.mahout.clustering.Cluster;
 import org.apache.mahout.clustering.ModelDistribution;
 import org.apache.mahout.clustering.dirichlet.DirichletClusterer;
-import 
org.apache.mahout.clustering.dirichlet.models.GaussianClusterDistribution;
+import 
org.apache.mahout.clustering.dirichlet.models.AsymmetricSampledNormalDistribution;
 import org.apache.mahout.common.RandomUtils;
 import org.apache.mahout.math.DenseVector;
 import org.apache.mahout.math.VectorWritable;
@@ -66,11 +66,11 @@ public class DisplayDirichlet extends Di
   }
 
   protected static void generateResults(ModelDistribution<VectorWritable> 
modelDist,
-                                 int numClusters,
-                                 int numIterations,
-                                 double alpha_0,
-                                 int thin,
-                                 int burnin) {
+                                        int numClusters,
+                                        int numIterations,
+                                        double alpha_0,
+                                        int thin,
+                                        int burnin) {
     DirichletClusterer dc = new DirichletClusterer(SAMPLE_DATA, modelDist, 
alpha_0, numClusters, thin, burnin);
     List<Cluster[]> result = dc.cluster(numIterations);
     printModels(result, burnin);
@@ -87,10 +87,10 @@ public class DisplayDirichlet extends Di
 
   public static void main(String[] args) throws Exception {
     VectorWritable modelPrototype = new VectorWritable(new DenseVector(2));
-    // ModelDistribution<VectorWritable> modelDist = new 
NormalModelDistribution(modelPrototype);
+    //ModelDistribution<VectorWritable> modelDist = new 
NormalModelDistribution(modelPrototype);
     // ModelDistribution<VectorWritable> modelDist = new 
SampledNormalDistribution(modelPrototype);
-    // ModelDistribution<VectorWritable> modelDist = new 
AsymmetricSampledNormalDistribution(modelPrototype);
-    ModelDistribution<VectorWritable> modelDist = new 
GaussianClusterDistribution(modelPrototype);
+     ModelDistribution<VectorWritable> modelDist = new 
AsymmetricSampledNormalDistribution(modelPrototype);
+    //ModelDistribution<VectorWritable> modelDist = new 
GaussianClusterDistribution(modelPrototype);
 
     RandomUtils.useTestSeed();
     generateSamples();

Modified: 
mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/dirichlet/Job.java
URL: 
http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/dirichlet/Job.java?rev=987686&r1=987685&r2=987686&view=diff
==============================================================================
--- 
mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/dirichlet/Job.java
 (original)
+++ 
mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/dirichlet/Job.java
 Sat Aug 21 02:21:57 2010
@@ -34,7 +34,7 @@ import org.apache.mahout.clustering.diri
 import org.apache.mahout.clustering.dirichlet.DirichletDriver;
 import org.apache.mahout.clustering.dirichlet.DirichletMapper;
 import 
org.apache.mahout.clustering.dirichlet.models.AbstractVectorModelDistribution;
-import 
org.apache.mahout.clustering.dirichlet.models.DistanceMeasureClusterDistribution;
+import 
org.apache.mahout.clustering.dirichlet.models.GaussianClusterDistribution;
 import org.apache.mahout.clustering.dirichlet.models.NormalModelDistribution;
 import org.apache.mahout.clustering.syntheticcontrol.Constants;
 import org.apache.mahout.clustering.syntheticcontrol.canopy.InputDriver;
@@ -61,8 +61,8 @@ public final class Job extends Dirichlet
       log.info("Running with default arguments");
       Path output = new Path("output");
       HadoopUtil.overwriteOutput(output);
-      AbstractVectorModelDistribution modelDistribution = new 
DistanceMeasureClusterDistribution(new VectorWritable(new 
RandomAccessSparseVector(60)));
-      new Job().job(new Path("testdata"), output, modelDistribution, 10, 5, 
0.5, 1, false, 0.001);
+      AbstractVectorModelDistribution modelDistribution = new 
GaussianClusterDistribution(new VectorWritable(new 
RandomAccessSparseVector(60)));
+      new Job().job(new Path("testdata"), output, modelDistribution, 10, 5, 
1.0, 1, true, 0);
     }
   }
 


Reply via email to