Author: jeastman
Date: Thu Sep 23 20:02:40 2010
New Revision: 1000599
URL: http://svn.apache.org/viewvc?rev=1000599&view=rev
Log:
Added a small prior to the standard deviation in pdf() computations to avoid
numeric instability when it is 0. All tests run.
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletClusterer.java
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/AsymmetricSampledNormalModel.java
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/GaussianCluster.java
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/NormalModel.java
mahout/trunk/utils/src/main/java/org/apache/mahout/clustering/cdbw/CDbwMapper.java
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletClusterer.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletClusterer.java?rev=1000599&r1=1000598&r2=1000599&view=diff
==============================================================================
---
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletClusterer.java
(original)
+++
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletClusterer.java
Thu Sep 23 20:02:40 2010
@@ -357,7 +357,8 @@ public class DirichletClusterer {
throws IOException {
Vector pi = new DenseVector(clusters.size());
for (int i = 0; i < clusters.size(); i++) {
- pi.set(i, clusters.get(i).getModel().pdf(vector));
+ double pdf = clusters.get(i).getModel().pdf(vector);
+ pi.set(i, pdf);
}
pi = pi.divide(pi.zSum());
if (emitMostLikely) {
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/AsymmetricSampledNormalModel.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/AsymmetricSampledNormalModel.java?rev=1000599&r1=1000598&r2=1000599&view=diff
==============================================================================
---
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/AsymmetricSampledNormalModel.java
(original)
+++
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/AsymmetricSampledNormalModel.java
Thu Sep 23 20:02:40 2010
@@ -120,7 +120,8 @@ public class AsymmetricSampledNormalMode
// TODO: is this reasonable? correct? It seems to work in some cases.
double pdf = 1;
for (int i = 0; i < x.size(); i++) {
- pdf *= UncommonDistributions.dNorm(x.getQuick(i),
getCenter().getQuick(i), getRadius().getQuick(i));
+ // small prior on stdDev to avoid numeric instability when stdDev==0
+ pdf *= UncommonDistributions.dNorm(x.getQuick(i), mean.getQuick(i),
stdDev.getQuick(i) + 0.000001);
}
return pdf;
}
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/GaussianCluster.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/GaussianCluster.java?rev=1000599&r1=1000598&r2=1000599&view=diff
==============================================================================
---
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/GaussianCluster.java
(original)
+++
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/GaussianCluster.java
Thu Sep 23 20:02:40 2010
@@ -38,7 +38,8 @@ public class GaussianCluster extends Abs
for (int i = 0; i < x.size(); i++) {
double x2 = x.get(i);
double m = getCenter().get(i);
- double s = getRadius().get(i);
+ // small prior on s to avoid numeric instability when s==0
+ double s = getRadius().get(i) + 0.000001;
double dNorm = UncommonDistributions.dNorm(x2, m, s);
pdf += dNorm;
}
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/NormalModel.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/NormalModel.java?rev=1000599&r1=1000598&r2=1000599&view=diff
==============================================================================
---
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/NormalModel.java
(original)
+++
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/NormalModel.java
Thu Sep 23 20:02:40 2010
@@ -36,27 +36,29 @@ import com.google.gson.GsonBuilder;
import com.google.gson.reflect.TypeToken;
public class NormalModel implements Cluster {
-
+
private static final double SQRT2PI = Math.sqrt(2.0 * Math.PI);
- private static final Type MODEL_TYPE = new TypeToken<Model<Vector>>()
{}.getType();
+ private static final Type MODEL_TYPE = new TypeToken<Model<Vector>>() {
+ }.getType();
private int id;
-
+
// the parameters
private Vector mean;
-
+
private double stdDev;
-
+
// the observation statistics, initialized by the first observation
private int s0;
-
+
private Vector s1;
-
+
private Vector s2;
- public NormalModel() { }
-
+ public NormalModel() {
+ }
+
public NormalModel(int id, Vector mean, double stdDev) {
this.id = id;
this.mean = mean;
@@ -65,19 +67,19 @@ public class NormalModel implements Clus
this.s1 = mean.like();
this.s2 = mean.like();
}
-
+
int getS0() {
return s0;
}
-
+
public Vector getMean() {
return mean;
}
-
+
public double getStdDev() {
return stdDev;
}
-
+
/**
* TODO: Return a proper sample from the posterior. For now, return an
instance with the same parameters
*
@@ -86,7 +88,7 @@ public class NormalModel implements Clus
public NormalModel sampleFromPosterior() {
return new NormalModel(id, mean, stdDev);
}
-
+
@Override
public void observe(VectorWritable x) {
s0++;
@@ -102,7 +104,7 @@ public class NormalModel implements Clus
s2 = s2.plus(v.times(v));
}
}
-
+
@Override
public void computeParameters() {
if (s0 == 0) {
@@ -117,26 +119,28 @@ public class NormalModel implements Clus
stdDev = Double.MIN_VALUE;
}
}
-
+
@Override
public double pdf(VectorWritable v) {
Vector x = v.get();
- double sd2 = stdDev * stdDev;
+ // small prior on std to avoid numeric instability when std==0
+ double std = stdDev + 0.000001;
+ double sd2 = std * std;
double exp = -(x.dot(x) - 2 * x.dot(mean) + mean.dot(mean)) / (2 * sd2);
double ex = Math.exp(exp);
- return ex / (stdDev * SQRT2PI);
+ return ex / (std * SQRT2PI);
}
-
+
@Override
public int count() {
return s0;
}
-
+
@Override
public String toString() {
return asFormatString(null);
}
-
+
@Override
public String asFormatString(String[] bindings) {
StringBuilder buf = new StringBuilder();
@@ -147,7 +151,7 @@ public class NormalModel implements Clus
buf.append(" sd=").append(String.format(Locale.ENGLISH, "%.2f",
stdDev)).append('}');
return buf.toString();
}
-
+
@Override
public void readFields(DataInput in) throws IOException {
this.id = in.readInt();
@@ -161,7 +165,7 @@ public class NormalModel implements Clus
temp.readFields(in);
this.s2 = temp.get();
}
-
+
@Override
public void write(DataOutput out) throws IOException {
out.writeInt(id);
@@ -171,7 +175,7 @@ public class NormalModel implements Clus
VectorWritable.writeVector(out, s1);
VectorWritable.writeVector(out, s2);
}
-
+
@Override
public String asJsonString() {
GsonBuilder builder = new GsonBuilder();
Modified:
mahout/trunk/utils/src/main/java/org/apache/mahout/clustering/cdbw/CDbwMapper.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/utils/src/main/java/org/apache/mahout/clustering/cdbw/CDbwMapper.java?rev=1000599&r1=1000598&r2=1000599&view=diff
==============================================================================
---
mahout/trunk/utils/src/main/java/org/apache/mahout/clustering/cdbw/CDbwMapper.java
(original)
+++
mahout/trunk/utils/src/main/java/org/apache/mahout/clustering/cdbw/CDbwMapper.java
Thu Sep 23 20:02:40 2010
@@ -59,6 +59,9 @@ public class CDbwMapper extends Mapper<I
WeightedVectorWritable currentMDP = mostDistantPoints.get(key);
List<VectorWritable> refPoints = representativePoints.get(key);
+ if (refPoints == null){
+ System.out.println();
+ }
double totalDistance = 0.0;
for (VectorWritable refPoint : refPoints) {
totalDistance += measure.distance(refPoint.get(), point.getVector());