Author: jeastman
Date: Sat Aug 21 02:21:57 2010
New Revision: 987686
URL: http://svn.apache.org/viewvc?rev=987686&view=rev
Log:
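MAHOUT-479: Fixed a bug in TestVectorModelClassifier and reworked pdf() in
GaussianCluster and AsymmetricSampledNormalModel (ASNModel) so that both use
UncommonDistributions.dNorm() per component. Synthetic control seems to work
better on Dirichlet now, but I'm troubled by the impact of the two differing
pdf() implementations (GaussianCluster averages the component densities,
ASNModel multiplies them) on the outcome.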
Removed:
mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/dirichlet/README.txt
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletClusterer.java
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/AsymmetricSampledNormalModel.java
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/GaussianCluster.java
mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestVectorModelClassifier.java
mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayDirichlet.java
mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/dirichlet/Job.java
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletClusterer.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletClusterer.java?rev=987686&r1=987685&r2=987686&view=diff
==============================================================================
---
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletClusterer.java
(original)
+++
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletClusterer.java
Sat Aug 21 02:21:57 2010
@@ -404,12 +404,16 @@ public class DirichletClusterer {
*/
private void emitMostLikelyCluster(VectorWritable vector,
List<DirichletCluster> clusters, Vector pi, Writer writer)
throws IOException {
+ double maxPdf = 0;
+ int clusterId = -1;
for (int i = 0; i < clusters.size(); i++) {
double pdf = pi.get(i);
- if (pdf > threshold && clusters.get(i).getTotalCount() > 0) {
- //System.out.println(i + ": " + ClusterBase.formatVector(vector.get(),
null));
- writer.append(new IntWritable(i), new WeightedVectorWritable(pdf,
vector));
+ if (pdf > maxPdf) {
+ maxPdf = pdf;
+ clusterId = i;
}
}
+ //System.out.println(i + ": " + ClusterBase.formatVector(vector.get(),
null));
+ writer.append(new IntWritable(clusterId), new
WeightedVectorWritable(maxPdf, vector));
}
}
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/AsymmetricSampledNormalModel.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/AsymmetricSampledNormalModel.java?rev=987686&r1=987685&r2=987686&view=diff
==============================================================================
---
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/AsymmetricSampledNormalModel.java
(original)
+++
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/AsymmetricSampledNormalModel.java
Sat Aug 21 02:21:57 2010
@@ -26,6 +26,7 @@ import org.apache.mahout.clustering.Abst
import org.apache.mahout.clustering.Cluster;
import org.apache.mahout.clustering.JsonModelAdapter;
import org.apache.mahout.clustering.Model;
+import org.apache.mahout.clustering.dirichlet.UncommonDistributions;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
import org.apache.mahout.math.function.SquareRootFunction;
@@ -35,27 +36,27 @@ import com.google.gson.GsonBuilder;
import com.google.gson.reflect.TypeToken;
public class AsymmetricSampledNormalModel implements Cluster {
-
- private static final double SQRT2PI = Math.sqrt(2.0 * Math.PI);
- private static final Type MODEL_TYPE = new TypeToken<Model<Vector>>() {
}.getType();
+
+ private static final Type MODEL_TYPE = new TypeToken<Model<Vector>>() {
+ }.getType();
private int id;
-
+
// the parameters
private Vector mean;
-
+
private Vector stdDev;
-
+
// the observation statistics, initialized by the first observation
private int s0;
-
+
private Vector s1;
-
+
private Vector s2;
public AsymmetricSampledNormalModel() {
}
-
+
public AsymmetricSampledNormalModel(int id, Vector mean, Vector stdDev) {
this.id = id;
this.mean = mean;
@@ -64,15 +65,15 @@ public class AsymmetricSampledNormalMode
this.s1 = mean.like();
this.s2 = mean.like();
}
-
+
public Vector getMean() {
return mean;
}
-
+
public Vector getStdDev() {
return stdDev;
}
-
+
/**
* Return an instance with the same parameters
*
@@ -81,7 +82,7 @@ public class AsymmetricSampledNormalMode
public AsymmetricSampledNormalModel sampleFromPosterior() {
return new AsymmetricSampledNormalModel(id, mean, stdDev);
}
-
+
@Override
public void observe(VectorWritable v) {
Vector x = v.get();
@@ -97,7 +98,7 @@ public class AsymmetricSampledNormalMode
s2 = s2.plus(x.times(x));
}
}
-
+
@Override
public void computeParameters() {
if (s0 == 0) {
@@ -111,44 +112,29 @@ public class AsymmetricSampledNormalMode
stdDev.assign(Double.MIN_NORMAL);
}
}
-
- /**
- * Calculate a pdf using the supplied sample and stdDev
- *
- * @param x
- * a Vector sample
- * @param sd
- * a double std deviation
- */
- private double pdf(Vector x, double sd) {
- double sd2 = sd * sd;
- double exp = -(x.dot(x) - 2 * x.dot(mean) + mean.dot(mean)) / (2 * sd2);
- double ex = Math.exp(exp);
- return ex / (sd * SQRT2PI);
- }
-
+
@Override
public double pdf(VectorWritable v) {
Vector x = v.get();
// return the product of the component pdfs
- // TODO: is this reasonable? correct?
- double pdf = pdf(x, stdDev.get(0));
- for (int i = 1; i < x.size(); i++) {
- pdf *= pdf(x, stdDev.get(i));
+ // TODO: is this reasonable? correct? It seems to work in some cases.
+ double pdf = 1;
+ for (int i = 0; i < x.size(); i++) {
+ pdf *= UncommonDistributions.dNorm(x.getQuick(i),
getCenter().getQuick(i), getRadius().getQuick(i));
}
return pdf;
}
-
+
@Override
public int count() {
return s0;
}
-
+
@Override
public String toString() {
return asFormatString(null);
}
-
+
@Override
public String asFormatString(String[] bindings) {
StringBuilder buf = new StringBuilder(50);
@@ -163,7 +149,7 @@ public class AsymmetricSampledNormalMode
buf.append('}');
return buf.toString();
}
-
+
@Override
public void readFields(DataInput in) throws IOException {
this.id = in.readInt();
@@ -178,7 +164,7 @@ public class AsymmetricSampledNormalMode
temp.readFields(in);
this.s2 = temp.get();
}
-
+
@Override
public void write(DataOutput out) throws IOException {
out.writeInt(id);
@@ -188,7 +174,7 @@ public class AsymmetricSampledNormalMode
VectorWritable.writeVector(out, s1);
VectorWritable.writeVector(out, s2);
}
-
+
@Override
public String asJsonString() {
GsonBuilder builder = new GsonBuilder();
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/GaussianCluster.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/GaussianCluster.java?rev=987686&r1=987685&r2=987686&view=diff
==============================================================================
---
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/GaussianCluster.java
(original)
+++
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/GaussianCluster.java
Sat Aug 21 02:21:57 2010
@@ -2,6 +2,7 @@ package org.apache.mahout.clustering.dir
import org.apache.mahout.clustering.AbstractCluster;
import org.apache.mahout.clustering.Model;
+import org.apache.mahout.clustering.dirichlet.UncommonDistributions;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
@@ -35,28 +36,17 @@ public class GaussianCluster extends Abs
@Override
public double pdf(VectorWritable vw) {
Vector x = vw.get();
- // return the product of the component pdfs
+ // return the average of the component pdfs
// TODO: is this reasonable? correct?
- double pdf = pdf(x, getRadius().get(0));
- for (int i = 1; i < x.size(); i++) {
- pdf *= pdf(x, getRadius().get(i));
+ double pdf = 0;
+ for (int i = 0; i < x.size(); i++) {
+ double x2 = x.get(i);
+ double m = getCenter().get(i);
+ double s = getRadius().get(i);
+ double dNorm = UncommonDistributions.dNorm(x2, m, s);
+ pdf += dNorm;
}
- return pdf;
- }
-
- /**
- * Calculate a pdf using the supplied sample and stdDev
- *
- * @param x
- * a Vector sample
- * @param sd
- * a double std deviation
- */
- private double pdf(Vector x, double sd) {
- double sd2 = sd * sd;
- double exp = -(x.dot(x) - 2 * x.dot(getCenter()) +
getCenter().dot(getCenter())) / (2 * sd2);
- double ex = Math.exp(exp);
- return ex / (sd * SQRT2PI);
+ return pdf / x.size();
}
}
Modified:
mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestVectorModelClassifier.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestVectorModelClassifier.java?rev=987686&r1=987685&r2=987686&view=diff
==============================================================================
---
mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestVectorModelClassifier.java
(original)
+++
mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestVectorModelClassifier.java
Sat Aug 21 02:21:57 2010
@@ -6,6 +6,7 @@ import java.util.List;
import org.apache.commons.lang.NotImplementedException;
import org.apache.mahout.classifier.AbstractVectorClassifier;
import org.apache.mahout.clustering.canopy.Canopy;
+import
org.apache.mahout.clustering.dirichlet.models.AsymmetricSampledNormalModel;
import org.apache.mahout.clustering.dirichlet.models.GaussianCluster;
import org.apache.mahout.clustering.fuzzykmeans.SoftCluster;
import org.apache.mahout.clustering.kmeans.Cluster;
@@ -93,9 +94,21 @@ public class TestVectorModelClassifier e
models.add(new GaussianCluster(new DenseVector(2).assign(-1), new
DenseVector(2).assign(1), 2));
AbstractVectorClassifier classifier = new VectorModelClassifier(models);
Vector pdf = classifier.classify(new DenseVector(2));
- assertEquals("[0,0]", "[0.107, 0.787, 0.107]",
AbstractCluster.formatVector(pdf, null));
+ assertEquals("[0,0]", "[0.274, 0.452, 0.274]",
AbstractCluster.formatVector(pdf, null));
+ pdf = classifier.classify(new DenseVector(2).assign(2));
+ assertEquals("[2,2]", "[0.806, 0.180, 0.015]",
AbstractCluster.formatVector(pdf, null));
+ }
+
+ public void testASNClusterClassification() {
+ List<Model<VectorWritable>> models = new
ArrayList<Model<VectorWritable>>();
+ models.add(new AsymmetricSampledNormalModel(0, new
DenseVector(2).assign(1), new DenseVector(2).assign(1)));
+ models.add(new AsymmetricSampledNormalModel(1, new DenseVector(2), new
DenseVector(2).assign(1)));
+ models.add(new AsymmetricSampledNormalModel(2, new
DenseVector(2).assign(-1), new DenseVector(2).assign(1)));
+ AbstractVectorClassifier classifier = new VectorModelClassifier(models);
+ Vector pdf = classifier.classify(new DenseVector(2));
+ assertEquals("[0,0]", "[0.212, 0.576, 0.212]",
AbstractCluster.formatVector(pdf, null));
pdf = classifier.classify(new DenseVector(2).assign(2));
- assertEquals("[2,2]", "[0.998, 0.002, 0.000]",
AbstractCluster.formatVector(pdf, null));
+ assertEquals("[2,2]", "[0.952, 0.047, 0.000]",
AbstractCluster.formatVector(pdf, null));
}
}
Modified:
mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayDirichlet.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayDirichlet.java?rev=987686&r1=987685&r2=987686&view=diff
==============================================================================
---
mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayDirichlet.java
(original)
+++
mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayDirichlet.java
Sat Aug 21 02:21:57 2010
@@ -25,7 +25,7 @@ import java.util.List;
import org.apache.mahout.clustering.Cluster;
import org.apache.mahout.clustering.ModelDistribution;
import org.apache.mahout.clustering.dirichlet.DirichletClusterer;
-import
org.apache.mahout.clustering.dirichlet.models.GaussianClusterDistribution;
+import
org.apache.mahout.clustering.dirichlet.models.AsymmetricSampledNormalDistribution;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.VectorWritable;
@@ -66,11 +66,11 @@ public class DisplayDirichlet extends Di
}
protected static void generateResults(ModelDistribution<VectorWritable>
modelDist,
- int numClusters,
- int numIterations,
- double alpha_0,
- int thin,
- int burnin) {
+ int numClusters,
+ int numIterations,
+ double alpha_0,
+ int thin,
+ int burnin) {
DirichletClusterer dc = new DirichletClusterer(SAMPLE_DATA, modelDist,
alpha_0, numClusters, thin, burnin);
List<Cluster[]> result = dc.cluster(numIterations);
printModels(result, burnin);
@@ -87,10 +87,10 @@ public class DisplayDirichlet extends Di
public static void main(String[] args) throws Exception {
VectorWritable modelPrototype = new VectorWritable(new DenseVector(2));
- // ModelDistribution<VectorWritable> modelDist = new
NormalModelDistribution(modelPrototype);
+ //ModelDistribution<VectorWritable> modelDist = new
NormalModelDistribution(modelPrototype);
// ModelDistribution<VectorWritable> modelDist = new
SampledNormalDistribution(modelPrototype);
- // ModelDistribution<VectorWritable> modelDist = new
AsymmetricSampledNormalDistribution(modelPrototype);
- ModelDistribution<VectorWritable> modelDist = new
GaussianClusterDistribution(modelPrototype);
+ ModelDistribution<VectorWritable> modelDist = new
AsymmetricSampledNormalDistribution(modelPrototype);
+ //ModelDistribution<VectorWritable> modelDist = new
GaussianClusterDistribution(modelPrototype);
RandomUtils.useTestSeed();
generateSamples();
Modified:
mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/dirichlet/Job.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/dirichlet/Job.java?rev=987686&r1=987685&r2=987686&view=diff
==============================================================================
---
mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/dirichlet/Job.java
(original)
+++
mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/dirichlet/Job.java
Sat Aug 21 02:21:57 2010
@@ -34,7 +34,7 @@ import org.apache.mahout.clustering.diri
import org.apache.mahout.clustering.dirichlet.DirichletDriver;
import org.apache.mahout.clustering.dirichlet.DirichletMapper;
import
org.apache.mahout.clustering.dirichlet.models.AbstractVectorModelDistribution;
-import
org.apache.mahout.clustering.dirichlet.models.DistanceMeasureClusterDistribution;
+import
org.apache.mahout.clustering.dirichlet.models.GaussianClusterDistribution;
import org.apache.mahout.clustering.dirichlet.models.NormalModelDistribution;
import org.apache.mahout.clustering.syntheticcontrol.Constants;
import org.apache.mahout.clustering.syntheticcontrol.canopy.InputDriver;
@@ -61,8 +61,8 @@ public final class Job extends Dirichlet
log.info("Running with default arguments");
Path output = new Path("output");
HadoopUtil.overwriteOutput(output);
- AbstractVectorModelDistribution modelDistribution = new
DistanceMeasureClusterDistribution(new VectorWritable(new
RandomAccessSparseVector(60)));
- new Job().job(new Path("testdata"), output, modelDistribution, 10, 5,
0.5, 1, false, 0.001);
+ AbstractVectorModelDistribution modelDistribution = new
GaussianClusterDistribution(new VectorWritable(new
RandomAccessSparseVector(60)));
+ new Job().job(new Path("testdata"), output, modelDistribution, 10, 5,
1.0, 1, true, 0);
}
}