Author: jeastman
Date: Mon Apr 18 04:19:01 2011
New Revision: 1094222
URL: http://svn.apache.org/viewvc?rev=1094222&view=rev
Log:
MAHOUT-479: Implemented ClusterClassifier, ClusteringPolicy(s) and
ClusterIterator which can duplicate k-means and Dirichlet clustering in
sequential execution only. Added unit tests and switched DisplayKMeans and
DisplayDirichlet to use the ClusterIterator. Gives a pretty good sanity check
of the clustering.
Changed pdf() implementation of GaussianCluster to use the product of the
component pdfs instead of their average. Seems to work a lot better.
Deprecated a number of old Dirichlet models and the experimental
VectorModelClassifier.
Changed the type of AbstractCluster numPoints to long from int.
All unit tests run.
Added:
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusterClassifier.java
(with props)
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusterIterator.java
(with props)
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusteringPolicy.java
(with props)
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/DirichletClusteringPolicy.java
(with props)
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/KMeansClusteringPolicy.java
(with props)
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/AbstractCluster.java
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/Cluster.java
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/Model.java
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/VectorModelClassifier.java
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletCluster.java
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/AsymmetricSampledNormalDistribution.java
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/AsymmetricSampledNormalModel.java
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/GaussianCluster.java
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/L1Model.java
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/L1ModelDistribution.java
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/NormalModel.java
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/NormalModelDistribution.java
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/SampledNormalDistribution.java
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/SampledNormalModel.java
mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestVectorModelClassifier.java
mahout/trunk/core/src/test/java/org/apache/mahout/clustering/meanshift/TestMeanShift.java
mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayDirichlet.java
mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayKMeans.java
mahout/trunk/utils/src/test/java/org/apache/mahout/clustering/dirichlet/TestL1ModelClustering.java
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/AbstractCluster.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/AbstractCluster.java?rev=1094222&r1=1094221&r2=1094222&view=diff
==============================================================================
---
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/AbstractCluster.java
(original)
+++
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/AbstractCluster.java
Mon Apr 18 04:19:01 2011
@@ -38,7 +38,7 @@ public abstract class AbstractCluster im
// cluster persistent state
private int id;
- private int numPoints;
+ private long numPoints;
private Vector center;
@@ -84,10 +84,10 @@ public abstract class AbstractCluster im
}
/**
- * @param numPoints the numPoints to set
+ * @param l the numPoints to set
*/
- protected void setNumPoints(int numPoints) {
- this.numPoints = numPoints;
+ protected void setNumPoints(long l) {
+ this.numPoints = l;
}
/**
@@ -172,7 +172,7 @@ public abstract class AbstractCluster im
}
@Override
- public int getNumPoints() {
+ public long getNumPoints() {
return numPoints;
}
@@ -199,7 +199,7 @@ public abstract class AbstractCluster im
@Override
public void readFields(DataInput in) throws IOException {
this.id = in.readInt();
- this.numPoints = in.readInt();
+ this.numPoints = in.readLong();
VectorWritable temp = new VectorWritable();
temp.readFields(in);
this.center = temp.get();
@@ -210,7 +210,7 @@ public abstract class AbstractCluster im
@Override
public void write(DataOutput out) throws IOException {
out.writeInt(id);
- out.writeInt(numPoints);
+ out.writeLong(numPoints);
VectorWritable.writeVector(out, center);
VectorWritable.writeVector(out, radius);
}
@@ -301,7 +301,7 @@ public abstract class AbstractCluster im
}
@Override
- public int count() {
+ public long count() {
return getNumPoints();
}
}
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/Cluster.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/Cluster.java?rev=1094222&r1=1094221&r2=1094222&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/Cluster.java
(original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/Cluster.java
Mon Apr 18 04:19:01 2011
@@ -61,7 +61,7 @@ public interface Cluster extends Model<V
* Get an integer denoting the number of points observed by this cluster
* @return an integer
*/
- int getNumPoints();
+ long getNumPoints();
/**
* Produce a custom, human-friendly, printable representation of the Cluster.
Added:
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusterClassifier.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusterClassifier.java?rev=1094222&view=auto
==============================================================================
---
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusterClassifier.java
(added)
+++
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusterClassifier.java
Mon Apr 18 04:19:01 2011
@@ -0,0 +1,155 @@
+/* Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.clustering;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.List;
+
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.classifier.AbstractVectorClassifier;
+import org.apache.mahout.classifier.OnlineLearner;
+import org.apache.mahout.clustering.fuzzykmeans.FuzzyKMeansClusterer;
+import org.apache.mahout.clustering.fuzzykmeans.SoftCluster;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.function.TimesFunction;
+
+/**
+ * This classifier works with any clustering Cluster. It is initialized with a
+ * list of compatible clusters and thereafter it can classify any new Vector
+ * into one or more of the clusters based upon the pdf() function which each
+ * cluster supports.
+ *
+ * In addition, it is an OnlineLearner and can be trained. Training amounts to
+ * asking the actual model to observe the vector and closing the classifier
+ * causes all the models to computeParameters.
+ */
+public class ClusterClassifier extends AbstractVectorClassifier implements
+ OnlineLearner, Writable {
+
+ private List<Cluster> models;
+
+ private String modelClass;
+
+ /**
+ * The public constructor accepts a list of clusters to become the models
+ *
+ * @param models
+ * a List<Cluster>
+ */
+ public ClusterClassifier(List<Cluster> models) {
+ this.models = models;
+ modelClass = models.get(0).getClass().getName();
+ }
+
+ // needed for serialization/deserialization
+ public ClusterClassifier() {}
+
+ @Override
+ public Vector classify(Vector instance) {
+ Vector pdfs = new DenseVector(getModels().size());
+ if (getModels().get(0) instanceof SoftCluster) {
+ Collection<SoftCluster> clusters = new ArrayList<SoftCluster>();
+ List<Double> distances = new ArrayList<Double>();
+ for (Cluster model : getModels()) {
+ SoftCluster sc = (SoftCluster) model;
+ clusters.add(sc);
+ distances.add(sc.getMeasure().distance(instance, sc.getCenter()));
+ }
+ return new FuzzyKMeansClusterer().computePi(clusters, distances);
+ } else {
+ int i = 0;
+ for (Cluster model : getModels()) {
+ pdfs.set(i++, model.pdf(new VectorWritable(instance)));
+ }
+ return pdfs.assign(new TimesFunction(), 1.0 / pdfs.zSum());
+ }
+ }
+
+ @Override
+ public double classifyScalar(Vector instance) {
+ if (getModels().size() == 2) {
+ double pdf0 = getModels().get(0).pdf(new VectorWritable(instance));
+ double pdf1 = getModels().get(1).pdf(new VectorWritable(instance));
+ return pdf0 / (pdf0 + pdf1);
+ }
+ throw new IllegalStateException();
+ }
+
+ @Override
+ public int numCategories() {
+ return getModels().size();
+ }
+
+ public void write(DataOutput out) throws IOException {
+ out.writeInt(getModels().size());
+ out.writeUTF(modelClass);
+ for (Cluster cluster : getModels()) {
+ cluster.write(out);
+ }
+ }
+
+ public void readFields(DataInput in) throws IOException {
+ int size = in.readInt();
+ modelClass = in.readUTF();
+ ClassLoader ccl = Thread.currentThread().getContextClassLoader();
+ try {
+ Class<? extends Cluster> factory = ccl.loadClass(modelClass).asSubclass(
+ Cluster.class);
+
+ models = new ArrayList<Cluster>();
+ for (int i = 0; i < size; i++) {
+ Cluster element = factory.newInstance();
+ element.readFields(in);
+ getModels().add(element);
+ }
+ } catch (ClassNotFoundException e) {
+ throw new IllegalStateException(e);
+ } catch (InstantiationException e) {
+ throw new IllegalStateException(e);
+ } catch (IllegalAccessException e) {
+ throw new IllegalStateException(e);
+ }
+ }
+
+ public void train(int actual, Vector instance) {
+ getModels().get(actual).observe(new VectorWritable(instance));
+ }
+
+ public void train(long trackingKey, String groupKey, int actual,
+ Vector instance) {
+ getModels().get(actual).observe(new VectorWritable(instance));
+ }
+
+ public void train(long trackingKey, int actual, Vector instance) {
+ getModels().get(actual).observe(new VectorWritable(instance));
+ }
+
+ public void close() {
+ for (Cluster cluster : getModels()) {
+ cluster.computeParameters();
+ }
+ }
+
+ public List<Cluster> getModels() {
+ return models;
+ }
+}
Propchange:
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusterClassifier.java
------------------------------------------------------------------------------
svn:mime-type = text/plain
Added:
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusterIterator.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusterIterator.java?rev=1094222&view=auto
==============================================================================
---
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusterIterator.java
(added)
+++
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusterIterator.java
Mon Apr 18 04:19:01 2011
@@ -0,0 +1,70 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.clustering;
+
+import java.util.List;
+
+import org.apache.mahout.math.Vector;
+
+/**
+ * This is an experimental clustering iterator which works with a
+ * ClusteringPolicy and a prior ClusterClassifier which has been initialized
+ * with a set of models. To date, it has been tested with k-means and Dirichlet
+ * clustering. See examples DisplayKMeans and DisplayDirichlet which have been
+ * switched over to use it.
+ */
+public class ClusterIterator {
+
+  // Selects a model for each vector; updated after every iteration
+  private final ClusteringPolicy policy;
+
+  /**
+   * @param policy
+   *          the ClusteringPolicy used to assign vectors to models
+   */
+  public ClusterIterator(ClusteringPolicy policy) {
+    this.policy = policy;
+  }
+
+  /**
+   * Iterate over data using a prior-trained ClusterClassifier, for a number of
+   * iterations. Each iteration classifies every vector, lets the policy pick
+   * one model for it, trains that model, then closes the classifier and
+   * updates the policy.
+   *
+   * The prior is mutated in place and returned.
+   *
+   * @param data
+   *          a List<Vector> of input vectors
+   * @param prior
+   *          the prior-trained ClusterClassifier
+   * @param numIterations
+   *          the int number of iterations to perform
+   * @return the posterior ClusterClassifier (the same instance as prior)
+   */
+  public ClusterClassifier iterate(List<Vector> data, ClusterClassifier prior,
+      int numIterations) {
+    for (int iteration = 1; iteration <= numIterations; iteration++) {
+      for (Vector vector : data) {
+        // classification yields probabilities
+        Vector pdfs = prior.classify(vector);
+        // policy selects a model given those probabilities
+        int selected = policy.select(pdfs);
+        // training makes only the selected model observe this vector
+        prior.train(selected, vector);
+      }
+      // compute the posterior models
+      prior.close();
+      // update the policy
+      policy.update(prior);
+    }
+    return prior;
+  }
+}
Propchange:
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusterIterator.java
------------------------------------------------------------------------------
svn:mime-type = text/plain
Added:
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusteringPolicy.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusteringPolicy.java?rev=1094222&view=auto
==============================================================================
---
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusteringPolicy.java
(added)
+++
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusteringPolicy.java
Mon Apr 18 04:19:01 2011
@@ -0,0 +1,44 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.clustering;
+
+import org.apache.mahout.math.Vector;
+
+/**
+ * A ClusteringPolicy captures the semantics of assignment of points to
+ * clusters.
+ */
+public interface ClusteringPolicy {
+
+  /**
+   * Return the index of the most appropriate model.
+   *
+   * @param pdfs
+   *          a Vector of pdfs, one entry per model
+   * @return an int index
+   */
+  int select(Vector pdfs);
+
+  /**
+   * Update the policy with the given classifier.
+   *
+   * @param posterior
+   *          a ClusterClassifier
+   */
+  void update(ClusterClassifier posterior);
+
+}
\ No newline at end of file
Propchange:
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusteringPolicy.java
------------------------------------------------------------------------------
svn:mime-type = text/plain
Added:
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/DirichletClusteringPolicy.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/DirichletClusteringPolicy.java?rev=1094222&view=auto
==============================================================================
---
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/DirichletClusteringPolicy.java
(added)
+++
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/DirichletClusteringPolicy.java
Mon Apr 18 04:19:01 2011
@@ -0,0 +1,54 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.clustering;
+
+import org.apache.mahout.clustering.dirichlet.UncommonDistributions;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+
+/**
+ * A ClusteringPolicy for Dirichlet clustering. A model is selected by
+ * UncommonDistributions.rMultinom over the classifier pdfs weighted
+ * element-wise by a mixture, and the mixture is re-drawn from a Dirichlet
+ * distribution over the per-model counts accumulated across all iterations.
+ */
+public class DirichletClusteringPolicy implements ClusteringPolicy {
+
+  // Alpha_0 primes the Dirichlet distribution
+  private final double alpha0;
+
+  // Total observed per model over all time
+  private final Vector totalCounts;
+
+  // The mixture is the Dirichlet distribution of the total Cluster counts over
+  // all iterations
+  private Vector mixture;
+
+  /**
+   * @param k
+   *          the number of models
+   * @param alpha0
+   *          the alpha_0 parameter priming the Dirichlet distribution
+   */
+  public DirichletClusteringPolicy(int k, double alpha0) {
+    this.totalCounts = new DenseVector(k);
+    this.alpha0 = alpha0;
+    this.mixture = UncommonDistributions.rDirichlet(totalCounts, alpha0);
+  }
+
+  @Override
+  public int select(Vector pdfs) {
+    return UncommonDistributions.rMultinom(pdfs.times(mixture));
+  }
+
+  // update the total counts and then the mixture
+  @Override
+  public void update(ClusterClassifier posterior) {
+    for (int i = 0; i < totalCounts.size(); i++) {
+      long nObserved = posterior.getModels().get(i).getNumPoints();
+      totalCounts.set(i, totalCounts.get(i) + nObserved);
+    }
+    mixture = UncommonDistributions.rDirichlet(totalCounts, alpha0);
+  }
+}
Propchange:
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/DirichletClusteringPolicy.java
------------------------------------------------------------------------------
svn:mime-type = text/plain
Added:
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/KMeansClusteringPolicy.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/KMeansClusteringPolicy.java?rev=1094222&view=auto
==============================================================================
---
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/KMeansClusteringPolicy.java
(added)
+++
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/KMeansClusteringPolicy.java
Mon Apr 18 04:19:01 2011
@@ -0,0 +1,44 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.clustering;
+
+import org.apache.mahout.math.Vector;
+
+/**
+ * This is a simple maximum likelihood clustering policy, suitable for k-means
+ * clustering.
+ */
+public class KMeansClusteringPolicy implements ClusteringPolicy {
+
+  /**
+   * Select the model with the largest pdf.
+   *
+   * @param pdfs
+   *          a Vector of pdfs, one entry per model
+   * @return the index of the maximum entry
+   */
+  @Override
+  public int select(Vector pdfs) {
+    return pdfs.maxValueIndex();
+  }
+
+  /**
+   * No-op: selection depends only on the pdfs passed to select(), so this
+   * policy has no state to update.
+   */
+  @Override
+  public void update(ClusterClassifier posterior) {
+    // nothing to do here
+  }
+
+}
Propchange:
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/KMeansClusteringPolicy.java
------------------------------------------------------------------------------
svn:mime-type = text/plain
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/Model.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/Model.java?rev=1094222&r1=1094221&r2=1094222&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/Model.java
(original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/Model.java Mon
Apr 18 04:19:01 2011
@@ -55,7 +55,7 @@ public interface Model<O> extends Writab
*
* @return an int
*/
- int count();
+ long count();
/**
* @return a sample of my posterior model
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/VectorModelClassifier.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/VectorModelClassifier.java?rev=1094222&r1=1094221&r2=1094222&view=diff
==============================================================================
---
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/VectorModelClassifier.java
(original)
+++
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/VectorModelClassifier.java
Mon Apr 18 04:19:01 2011
@@ -31,6 +31,7 @@ import org.apache.mahout.math.function.T
* This classifier works with any of the clustering Models. It is initialized
with
* a list of compatible Models and thereafter it can classify any new Vector
into
* one or more of the Models based upon the pdf() function which each Model
supports.
+ * @deprecated in favor of ClusterClassifier
*/
public class VectorModelClassifier extends AbstractVectorClassifier {
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletCluster.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletCluster.java?rev=1094222&r1=1094221&r2=1094222&view=diff
==============================================================================
---
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletCluster.java
(original)
+++
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletCluster.java
Mon Apr 18 04:19:01 2011
@@ -129,7 +129,7 @@ public class DirichletCluster implements
}
@Override
- public int getNumPoints() {
+ public long getNumPoints() {
return model.getNumPoints();
}
@@ -144,7 +144,7 @@ public class DirichletCluster implements
}
@Override
- public int count() {
+ public long count() {
return model.count();
}
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/AsymmetricSampledNormalDistribution.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/AsymmetricSampledNormalDistribution.java?rev=1094222&r1=1094221&r2=1094222&view=diff
==============================================================================
---
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/AsymmetricSampledNormalDistribution.java
(original)
+++
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/AsymmetricSampledNormalDistribution.java
Mon Apr 18 04:19:01 2011
@@ -26,6 +26,7 @@ import org.apache.mahout.math.VectorWrit
* An implementation of the ModelDistribution interface suitable for testing
the DirichletCluster algorithm.
* Uses a Normal Distribution to sample the prior model values. Model values
have a vector standard deviation,
* allowing assymetrical regions to be covered by a model.
+ * @deprecated use GaussianClusterDistribution instead
*/
public class AsymmetricSampledNormalDistribution extends
AbstractVectorModelDistribution {
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/AsymmetricSampledNormalModel.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/AsymmetricSampledNormalModel.java?rev=1094222&r1=1094221&r2=1094222&view=diff
==============================================================================
---
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/AsymmetricSampledNormalModel.java
(original)
+++
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/AsymmetricSampledNormalModel.java
Mon Apr 18 04:19:01 2011
@@ -32,6 +32,10 @@ import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
import org.apache.mahout.math.function.SquareRootFunction;
+/**
+ *
+ *@deprecated Use GaussianCluster instead
+ */
public class AsymmetricSampledNormalModel implements Cluster {
private int id;
@@ -137,7 +141,7 @@ public class AsymmetricSampledNormalMode
}
@Override
- public int count() {
+ public long count() {
return s0;
}
@@ -197,7 +201,7 @@ public class AsymmetricSampledNormalMode
}
@Override
- public int getNumPoints() {
+ public long getNumPoints() {
return s0;
}
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/GaussianCluster.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/GaussianCluster.java?rev=1094222&r1=1094221&r2=1094222&view=diff
==============================================================================
---
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/GaussianCluster.java
(original)
+++
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/GaussianCluster.java
Mon Apr 18 04:19:01 2011
@@ -24,43 +24,39 @@ import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
public class GaussianCluster extends AbstractCluster {
-
- public GaussianCluster() {
- }
-
+
+ public GaussianCluster() {}
+
public GaussianCluster(Vector point, int id2) {
super(point, id2);
}
-
+
public GaussianCluster(Vector center, Vector radius, int id) {
super(center, radius, id);
}
-
+
@Override
public String getIdentifier() {
return "GC:" + getId();
}
-
+
@Override
public Model<VectorWritable> sampleFromPosterior() {
return new GaussianCluster(getCenter(), getRadius(), getId());
}
-
+
@Override
public double pdf(VectorWritable vw) {
Vector x = vw.get();
- // return the average of the component pdfs
- // TODO: is this reasonable? correct?
- double pdf = 0;
+ // return the product of the component pdfs
+ // TODO: is this reasonable? correct? It seems to work in some cases.
+ double pdf = 1;
for (int i = 0; i < x.size(); i++) {
- double x2 = x.get(i);
- double m = getCenter().get(i);
- // small prior on s to avoid numeric instability when s==0
- double s = getRadius().get(i) + 0.000001;
- double dNorm = UncommonDistributions.dNorm(x2, m, s);
- pdf += dNorm;
+ // small prior on stdDev to avoid numeric instability when stdDev==0
+ pdf *= UncommonDistributions.dNorm(x.getQuick(i),
+ getCenter().getQuick(i), getRadius().getQuick(i) + 0.000001);
}
- return pdf / x.size();
+ return pdf;
}
-
+
}
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/L1Model.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/L1Model.java?rev=1094222&r1=1094221&r2=1094222&view=diff
==============================================================================
---
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/L1Model.java
(original)
+++
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/L1Model.java
Mon Apr 18 04:19:01 2011
@@ -31,6 +31,10 @@ import org.apache.mahout.common.paramete
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
+/**
+ *
+ *@deprecated use DistanceMeasureCluster instead
+ */
public class L1Model implements Cluster {
private static final DistanceMeasure MEASURE = new
ManhattanDistanceMeasure();
@@ -73,7 +77,7 @@ public class L1Model implements Cluster
}
@Override
- public int count() {
+ public long count() {
return counter;
}
@@ -137,7 +141,7 @@ public class L1Model implements Cluster
}
@Override
- public int getNumPoints() {
+ public long getNumPoints() {
return counter;
}
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/L1ModelDistribution.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/L1ModelDistribution.java?rev=1094222&r1=1094221&r2=1094222&view=diff
==============================================================================
---
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/L1ModelDistribution.java
(original)
+++
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/L1ModelDistribution.java
Mon Apr 18 04:19:01 2011
@@ -24,6 +24,7 @@ import org.apache.mahout.math.VectorWrit
/**
* An implementation of the ModelDistribution interface suitable for testing
the DirichletCluster algorithm.
* Uses a L1Distribution
+ * @deprecated Use DistanceMeasureClusterDistribution instead
*/
public class L1ModelDistribution extends AbstractVectorModelDistribution {
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/NormalModel.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/NormalModel.java?rev=1094222&r1=1094221&r2=1094222&view=diff
==============================================================================
---
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/NormalModel.java
(original)
+++
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/NormalModel.java
Mon Apr 18 04:19:01 2011
@@ -32,6 +32,10 @@ import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
import org.apache.mahout.math.function.SquareRootFunction;
+/**
+ *
+ *@deprecated use GaussianCluster instead
+ */
public class NormalModel implements Cluster {
private static final double SQRT2PI = Math.sqrt(2.0 * Math.PI);
@@ -142,7 +146,7 @@ public class NormalModel implements Clus
}
@Override
- public int count() {
+ public long count() {
return s0;
}
@@ -197,7 +201,7 @@ public class NormalModel implements Clus
}
@Override
- public int getNumPoints() {
+ public long getNumPoints() {
return s0;
}
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/NormalModelDistribution.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/NormalModelDistribution.java?rev=1094222&r1=1094221&r2=1094222&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/NormalModelDistribution.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/NormalModelDistribution.java Mon Apr 18 04:19:01 2011
@@ -24,6 +24,7 @@ import org.apache.mahout.math.VectorWrit
/**
* An implementation of the ModelDistribution interface suitable for testing
the DirichletCluster algorithm.
* Uses a Normal Distribution
+ * @deprecated Use GaussianClusterDistribution instead
*/
public class NormalModelDistribution extends AbstractVectorModelDistribution {
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/SampledNormalDistribution.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/SampledNormalDistribution.java?rev=1094222&r1=1094221&r2=1094222&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/SampledNormalDistribution.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/SampledNormalDistribution.java Mon Apr 18 04:19:01 2011
@@ -25,6 +25,7 @@ import org.apache.mahout.math.VectorWrit
/**
* An implementation of the ModelDistribution interface suitable for testing
the DirichletCluster algorithm.
* Uses a Normal Distribution to sample the prior model values.
+ * @deprecated Use GaussianClusterDistribution instead
*/
public class SampledNormalDistribution extends NormalModelDistribution {
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/SampledNormalModel.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/SampledNormalModel.java?rev=1094222&r1=1094221&r2=1094222&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/SampledNormalModel.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/SampledNormalModel.java Mon Apr 18 04:19:01 2011
@@ -22,6 +22,10 @@ import java.util.Locale;
import org.apache.mahout.clustering.AbstractCluster;
import org.apache.mahout.math.Vector;
+/**
+ *
+ *@deprecated Use GaussianCluster instead
+ */
public class SampledNormalModel extends NormalModel {
public SampledNormalModel() {
Modified: mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestVectorModelClassifier.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestVectorModelClassifier.java?rev=1094222&r1=1094221&r2=1094222&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestVectorModelClassifier.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestVectorModelClassifier.java Mon Apr 18 04:19:01 2011
@@ -37,21 +37,25 @@ import org.apache.mahout.math.VectorWrit
import org.junit.Test;
public final class TestVectorModelClassifier extends MahoutTestCase {
-
+
@Test
public void testDMClusterClassification() {
List<Model<VectorWritable>> models = new
ArrayList<Model<VectorWritable>>();
DistanceMeasure measure = new ManhattanDistanceMeasure();
- models.add(new DistanceMeasureCluster(new DenseVector(2).assign(1), 0,
measure));
+ models.add(new DistanceMeasureCluster(new DenseVector(2).assign(1), 0,
+ measure));
models.add(new DistanceMeasureCluster(new DenseVector(2), 1, measure));
- models.add(new DistanceMeasureCluster(new DenseVector(2).assign(-1), 2,
measure));
+ models.add(new DistanceMeasureCluster(new DenseVector(2).assign(-1), 2,
+ measure));
AbstractVectorClassifier classifier = new VectorModelClassifier(models);
Vector pdf = classifier.classify(new DenseVector(2));
- assertEquals("[0,0]", "[0.107, 0.787, 0.107]",
AbstractCluster.formatVector(pdf, null));
+ assertEquals("[0,0]", "[0.107, 0.787, 0.107]",
+ AbstractCluster.formatVector(pdf, null));
pdf = classifier.classify(new DenseVector(2).assign(2));
- assertEquals("[2,2]", "[0.867, 0.117, 0.016]",
AbstractCluster.formatVector(pdf, null));
+ assertEquals("[2,2]", "[0.867, 0.117, 0.016]",
+ AbstractCluster.formatVector(pdf, null));
}
-
+
@Test
public void testCanopyClassification() {
List<Model<VectorWritable>> models = new
ArrayList<Model<VectorWritable>>();
@@ -61,11 +65,13 @@ public final class TestVectorModelClassi
models.add(new Canopy(new DenseVector(2).assign(-1), 2, measure));
AbstractVectorClassifier classifier = new VectorModelClassifier(models);
Vector pdf = classifier.classify(new DenseVector(2));
- assertEquals("[0,0]", "[0.107, 0.787, 0.107]",
AbstractCluster.formatVector(pdf, null));
+ assertEquals("[0,0]", "[0.107, 0.787, 0.107]",
+ AbstractCluster.formatVector(pdf, null));
pdf = classifier.classify(new DenseVector(2).assign(2));
- assertEquals("[2,2]", "[0.867, 0.117, 0.016]",
AbstractCluster.formatVector(pdf, null));
+ assertEquals("[2,2]", "[0.867, 0.117, 0.016]",
+ AbstractCluster.formatVector(pdf, null));
}
-
+
@Test
public void testClusterClassification() {
List<Model<VectorWritable>> models = new
ArrayList<Model<VectorWritable>>();
@@ -75,11 +81,13 @@ public final class TestVectorModelClassi
models.add(new Cluster(new DenseVector(2).assign(-1), 2, measure));
AbstractVectorClassifier classifier = new VectorModelClassifier(models);
Vector pdf = classifier.classify(new DenseVector(2));
- assertEquals("[0,0]", "[0.107, 0.787, 0.107]",
AbstractCluster.formatVector(pdf, null));
+ assertEquals("[0,0]", "[0.107, 0.787, 0.107]",
+ AbstractCluster.formatVector(pdf, null));
pdf = classifier.classify(new DenseVector(2).assign(2));
- assertEquals("[2,2]", "[0.867, 0.117, 0.016]",
AbstractCluster.formatVector(pdf, null));
+ assertEquals("[2,2]", "[0.867, 0.117, 0.016]",
+ AbstractCluster.formatVector(pdf, null));
}
-
+
@Test
public void testMSCanopyClassification() {
List<Model<VectorWritable>> models = new
ArrayList<Model<VectorWritable>>();
@@ -91,10 +99,9 @@ public final class TestVectorModelClassi
try {
classifier.classify(new DenseVector(2));
fail("Expected NotImplementedException");
- } catch (NotImplementedException e) {
- }
+ } catch (NotImplementedException e) {}
}
-
+
@Test
public void testSoftClusterClassification() {
List<Model<VectorWritable>> models = new
ArrayList<Model<VectorWritable>>();
@@ -104,35 +111,47 @@ public final class TestVectorModelClassi
models.add(new SoftCluster(new DenseVector(2).assign(-1), 2, measure));
AbstractVectorClassifier classifier = new VectorModelClassifier(models);
Vector pdf = classifier.classify(new DenseVector(2));
- assertEquals("[0,0]", "[0.000, 1.000, 0.000]",
AbstractCluster.formatVector(pdf, null));
+ assertEquals("[0,0]", "[0.000, 1.000, 0.000]",
+ AbstractCluster.formatVector(pdf, null));
pdf = classifier.classify(new DenseVector(2).assign(2));
- assertEquals("[2,2]", "[0.735, 0.184, 0.082]",
AbstractCluster.formatVector(pdf, null));
+ assertEquals("[2,2]", "[0.735, 0.184, 0.082]",
+ AbstractCluster.formatVector(pdf, null));
}
-
+
@Test
public void testGaussianClusterClassification() {
List<Model<VectorWritable>> models = new
ArrayList<Model<VectorWritable>>();
- models.add(new GaussianCluster(new DenseVector(2).assign(1), new
DenseVector(2).assign(1), 0));
- models.add(new GaussianCluster(new DenseVector(2), new
DenseVector(2).assign(1), 1));
- models.add(new GaussianCluster(new DenseVector(2).assign(-1), new
DenseVector(2).assign(1), 2));
+ models.add(new GaussianCluster(new DenseVector(2).assign(1),
+ new DenseVector(2).assign(1), 0));
+ models.add(new GaussianCluster(new DenseVector(2), new DenseVector(2)
+ .assign(1), 1));
+ models.add(new GaussianCluster(new DenseVector(2).assign(-1),
+ new DenseVector(2).assign(1), 2));
AbstractVectorClassifier classifier = new VectorModelClassifier(models);
Vector pdf = classifier.classify(new DenseVector(2));
- assertEquals("[0,0]", "[0.274, 0.452, 0.274]",
AbstractCluster.formatVector(pdf, null));
+ assertEquals("[0,0]", "[0.212, 0.576, 0.212]",
+ AbstractCluster.formatVector(pdf, null));
pdf = classifier.classify(new DenseVector(2).assign(2));
- assertEquals("[2,2]", "[0.806, 0.180, 0.015]",
AbstractCluster.formatVector(pdf, null));
+ assertEquals("[2,2]", "[0.952, 0.047, 0.000]",
+ AbstractCluster.formatVector(pdf, null));
}
-
+
@Test
public void testASNClusterClassification() {
List<Model<VectorWritable>> models = new
ArrayList<Model<VectorWritable>>();
- models.add(new AsymmetricSampledNormalModel(0, new
DenseVector(2).assign(1), new DenseVector(2).assign(1)));
- models.add(new AsymmetricSampledNormalModel(1, new DenseVector(2), new
DenseVector(2).assign(1)));
- models.add(new AsymmetricSampledNormalModel(2, new
DenseVector(2).assign(-1), new DenseVector(2).assign(1)));
+ models.add(new AsymmetricSampledNormalModel(0,
+ new DenseVector(2).assign(1), new DenseVector(2).assign(1)));
+ models.add(new AsymmetricSampledNormalModel(1, new DenseVector(2),
+ new DenseVector(2).assign(1)));
+ models.add(new AsymmetricSampledNormalModel(2, new DenseVector(2)
+ .assign(-1), new DenseVector(2).assign(1)));
AbstractVectorClassifier classifier = new VectorModelClassifier(models);
Vector pdf = classifier.classify(new DenseVector(2));
- assertEquals("[0,0]", "[0.212, 0.576, 0.212]",
AbstractCluster.formatVector(pdf, null));
+ assertEquals("[0,0]", "[0.212, 0.576, 0.212]",
+ AbstractCluster.formatVector(pdf, null));
pdf = classifier.classify(new DenseVector(2).assign(2));
- assertEquals("[2,2]", "[0.952, 0.047, 0.000]",
AbstractCluster.formatVector(pdf, null));
+ assertEquals("[2,2]", "[0.952, 0.047, 0.000]",
+ AbstractCluster.formatVector(pdf, null));
}
-
+
}
Modified: mahout/trunk/core/src/test/java/org/apache/mahout/clustering/meanshift/TestMeanShift.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/meanshift/TestMeanShift.java?rev=1094222&r1=1094221&r2=1094222&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/clustering/meanshift/TestMeanShift.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/clustering/meanshift/TestMeanShift.java Mon Apr 18 04:19:01 2011
@@ -293,8 +293,8 @@ public final class TestMeanShift extends
assertEquals("values", 1, values.size());
MeanShiftCanopy reducerCanopy = values.get(0);
assertEquals("ids", refCanopy.getId(), reducerCanopy.getId());
- int refNumPoints = refCanopy.getNumPoints();
- int reducerNumPoints = reducerCanopy.getNumPoints();
+ long refNumPoints = refCanopy.getNumPoints();
+ long reducerNumPoints = reducerCanopy.getNumPoints();
assertEquals("numPoints", refNumPoints, reducerNumPoints);
String refCenter = refCanopy.getCenter().asFormatString();
String reducerCenter = reducerCanopy.getCenter().asFormatString();
Modified: mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayDirichlet.java
URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayDirichlet.java?rev=1094222&r1=1094221&r2=1094222&view=diff
==============================================================================
--- mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayDirichlet.java (original)
+++ mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayDirichlet.java Mon Apr 18 04:19:01 2011
@@ -20,35 +20,43 @@ package org.apache.mahout.clustering.dis
import java.awt.Graphics;
import java.awt.Graphics2D;
import java.util.ArrayList;
+import java.util.Iterator;
import java.util.List;
import org.apache.mahout.clustering.Cluster;
+import org.apache.mahout.clustering.ClusterClassifier;
+import org.apache.mahout.clustering.ClusterIterator;
+import org.apache.mahout.clustering.ClusteringPolicy;
+import org.apache.mahout.clustering.DirichletClusteringPolicy;
+import org.apache.mahout.clustering.Model;
import org.apache.mahout.clustering.ModelDistribution;
import org.apache.mahout.clustering.dirichlet.DirichletClusterer;
-import
org.apache.mahout.clustering.dirichlet.models.AsymmetricSampledNormalDistribution;
+import
org.apache.mahout.clustering.dirichlet.models.GaussianClusterDistribution;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
public class DisplayDirichlet extends DisplayClustering {
-
- private static final Logger log =
LoggerFactory.getLogger(DisplayDirichlet.class);
-
+
+ private static final Logger log = LoggerFactory
+ .getLogger(DisplayDirichlet.class);
+
public DisplayDirichlet() {
initialize();
this.setTitle("Dirichlet Process Clusters - Normal Distribution (>"
+ (int) (significance * 100) + "% of population)");
}
-
+
// Override the paint() method
@Override
public void paint(Graphics g) {
plotSampleData((Graphics2D) g);
plotClusters((Graphics2D) g);
}
-
+
protected static void printModels(Iterable<Cluster[]> result, int
significant) {
int row = 0;
StringBuilder models = new StringBuilder(100);
@@ -57,7 +65,8 @@ public class DisplayDirichlet extends Di
for (int k = 0; k < r.length; k++) {
Cluster model = r[k];
if (model.count() > significant) {
-
models.append('m').append(k).append(model.asFormatString(null)).append(", ");
+ models.append('m').append(k).append(model.asFormatString(null))
+ .append(", ");
}
}
models.append('\n');
@@ -65,34 +74,53 @@ public class DisplayDirichlet extends Di
models.append('\n');
log.info(models.toString());
}
-
- protected static void generateResults(ModelDistribution<VectorWritable>
modelDist,
- int numClusters,
- int numIterations,
- double alpha0,
- int thin,
- int burnin) {
- DirichletClusterer dc = new DirichletClusterer(SAMPLE_DATA, modelDist,
alpha0, numClusters, thin, burnin);
- List<Cluster[]> result = dc.cluster(numIterations);
- printModels(result, burnin);
- for (Cluster[] models : result) {
- List<Cluster> clusters = new ArrayList<Cluster>();
- for (Cluster cluster : models) {
- if (isSignificant(cluster)) {
- clusters.add(cluster);
+
+ protected static void generateResults(
+ ModelDistribution<VectorWritable> modelDist, int numClusters,
+ int numIterations, double alpha0, int thin, int burnin) {
+ boolean b = false;
+ if (b) {
+ DirichletClusterer dc = new DirichletClusterer(SAMPLE_DATA, modelDist,
+ alpha0, numClusters, thin, burnin);
+ List<Cluster[]> result = dc.cluster(numIterations);
+ printModels(result, burnin);
+ for (Cluster[] models : result) {
+ List<Cluster> clusters = new ArrayList<Cluster>();
+ for (Cluster cluster : models) {
+ if (isSignificant(cluster)) {
+ clusters.add(cluster);
+ }
}
+ CLUSTERS.add(clusters);
+ }
+ } else {
+ List<Vector> points = new ArrayList<Vector>();
+ for (VectorWritable sample : SAMPLE_DATA) {
+ points.add(sample.get());
}
- CLUSTERS.add(clusters);
+ ClusteringPolicy policy = new DirichletClusteringPolicy(numClusters,
+ numIterations);
+ List<Cluster> models = new ArrayList<Cluster>();
+ for (Model<VectorWritable> cluster : modelDist
+ .sampleFromPrior(numClusters)) {
+ models.add((Cluster) cluster);
+ }
+ ClusterClassifier prior = new ClusterClassifier(models);
+ ClusterIterator iterator = new ClusterIterator(policy);
+ ClusterClassifier posterior = iterator.iterate(points, prior, 5);
+ List<Cluster> models2 = posterior.getModels();
+ for (Iterator<Cluster> it = models2.iterator(); it.hasNext();) {
+ if (!isSignificant(it.next())) it.remove();
+ }
+ CLUSTERS.add(models2);
}
}
-
+
public static void main(String[] args) throws Exception {
VectorWritable modelPrototype = new VectorWritable(new DenseVector(2));
- //ModelDistribution<VectorWritable> modelDist = new
NormalModelDistribution(modelPrototype);
- // ModelDistribution<VectorWritable> modelDist = new
SampledNormalDistribution(modelPrototype);
- ModelDistribution<VectorWritable> modelDist = new
AsymmetricSampledNormalDistribution(modelPrototype);
- //ModelDistribution<VectorWritable> modelDist = new
GaussianClusterDistribution(modelPrototype);
-
+ ModelDistribution<VectorWritable> modelDist = new
GaussianClusterDistribution(
+ modelPrototype);
+
RandomUtils.useTestSeed();
generateSamples();
int numIterations = 40;
@@ -103,5 +131,5 @@ public class DisplayDirichlet extends Di
generateResults(modelDist, numClusters, numIterations, alpha0, thin,
burnin);
new DisplayDirichlet();
}
-
+
}
Modified: mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayKMeans.java
URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayKMeans.java?rev=1094222&r1=1094221&r2=1094222&view=diff
==============================================================================
--- mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayKMeans.java (original)
+++ mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayKMeans.java Mon Apr 18 04:19:01 2011
@@ -19,25 +19,35 @@ package org.apache.mahout.clustering.dis
import java.awt.Graphics;
import java.awt.Graphics2D;
+import java.util.ArrayList;
+import java.util.List;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
+import org.apache.mahout.clustering.ClusterClassifier;
+import org.apache.mahout.clustering.ClusterIterator;
+import org.apache.mahout.clustering.ClusteringPolicy;
+import org.apache.mahout.clustering.KMeansClusteringPolicy;
+import org.apache.mahout.clustering.Cluster;
import org.apache.mahout.clustering.kmeans.KMeansDriver;
import org.apache.mahout.clustering.kmeans.RandomSeedGenerator;
import org.apache.mahout.common.HadoopUtil;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.common.distance.DistanceMeasure;
import org.apache.mahout.common.distance.ManhattanDistanceMeasure;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
class DisplayKMeans extends DisplayClustering {
-
- //static List<List<Cluster>> result;
-
+
+ // static List<List<Cluster>> result;
+
DisplayKMeans() {
initialize();
- this.setTitle("k-Means Clusters (>" + (int) (significance * 100) + "% of
population)");
+ this.setTitle("k-Means Clusters (>" + (int) (significance * 100)
+ + "% of population)");
}
-
+
public static void main(String[] args) throws Exception {
DistanceMeasure measure = new ManhattanDistanceMeasure();
Path samples = new Path("samples");
@@ -45,55 +55,45 @@ class DisplayKMeans extends DisplayClust
Configuration conf = new Configuration();
HadoopUtil.delete(conf, samples);
HadoopUtil.delete(conf, output);
-
+
RandomUtils.useTestSeed();
DisplayClustering.generateSamples();
writeSampleData(samples);
- //boolean b = true;
+ boolean b = false;
int maxIter = 10;
double distanceThreshold = 0.001;
- //if (b) {
- Path clusters = RandomSeedGenerator.buildRandom(conf, samples, new
Path(output, "clusters-0"), 3, measure);
- KMeansDriver.run(samples,
- clusters,
- output,
- measure,
- distanceThreshold,
- maxIter,
- true,
- true);
- loadClusters(output);
- //} else {
- // List<Vector> points = new ArrayList<Vector>();
- // for (VectorWritable sample : SAMPLE_DATA) {
- // points.add(sample.get());
- // }
- // List<Cluster> initialClusters = new ArrayList<Cluster>();
- // int id = 0;
- // int numClusters = 3;
- // for (Vector point : points) {
- // if (initialClusters.size() < Math.min(numClusters, points.size())) {
- // initialClusters.add(new Cluster(point, id++));
- // } else {
- // break;
- // }
- // }
- //
- // result = KMeansClusterer.clusterPoints(points, initialClusters,
measure, maxIter, distanceThreshold);
- // for (List<Cluster> models : result) {
- // List<org.apache.mahout.clustering.Cluster> clusters = new
ArrayList<org.apache.mahout.clustering.Cluster>();
- // for (AbstractCluster cluster : models) {
- // org.apache.mahout.clustering.Cluster cluster2 =
(org.apache.mahout.clustering.Cluster) cluster;
- // if (isSignificant(cluster2)) {
- // clusters.add(cluster2);
- // }
- // }
- // CLUSTERS.add(clusters);
- // }
- //}
+ if (b) {
+ Path clusters = RandomSeedGenerator.buildRandom(conf, samples, new Path(
+ output, "clusters-0"), 3, measure);
+ KMeansDriver.run(samples, clusters, output, measure, distanceThreshold,
+ maxIter, true, true);
+ loadClusters(output);
+ } else {
+ List<Vector> points = new ArrayList<Vector>();
+ for (VectorWritable sample : SAMPLE_DATA) {
+ points.add(sample.get());
+ }
+ List<Cluster> initialClusters = new ArrayList<Cluster>();
+ int id = 0;
+ int numClusters = 3;
+ for (Vector point : points) {
+ if (initialClusters.size() < Math.min(numClusters, points.size())) {
+ initialClusters.add(new org.apache.mahout.clustering.kmeans.Cluster(
+ point, id++, measure));
+ } else {
+ break;
+ }
+ }
+
+ ClusterClassifier prior = new ClusterClassifier(initialClusters);
+ ClusteringPolicy policy = new KMeansClusteringPolicy();
+ ClusterClassifier posterior = new ClusterIterator(policy).iterate(points,
+ prior, 10);
+ CLUSTERS.add(posterior.getModels());
+ }
new DisplayKMeans();
}
-
+
// Override the paint() method
@Override
public void paint(Graphics g) {
Modified: mahout/trunk/utils/src/test/java/org/apache/mahout/clustering/dirichlet/TestL1ModelClustering.java
URL: http://svn.apache.org/viewvc/mahout/trunk/utils/src/test/java/org/apache/mahout/clustering/dirichlet/TestL1ModelClustering.java?rev=1094222&r1=1094221&r2=1094222&view=diff
==============================================================================
--- mahout/trunk/utils/src/test/java/org/apache/mahout/clustering/dirichlet/TestL1ModelClustering.java (original)
+++ mahout/trunk/utils/src/test/java/org/apache/mahout/clustering/dirichlet/TestL1ModelClustering.java Mon Apr 18 04:19:01 2011
@@ -178,7 +178,7 @@ public final class TestL1ModelClustering
private void printClusters(Model<VectorWritable>[] models,
List<VectorWritable> samples, String[] docs) {
for (int m = 0; m < models.length; m++) {
Model<VectorWritable> model = models[m];
- int count = model.count();
+ long count = model.count();
if (count == 0) {
continue;
}