Author: jeastman Date: Wed Apr 15 22:38:44 2009 New Revision: 765403 URL: http://svn.apache.org/viewvc?rev=765403&view=rev Log: Added examples of MeanShift and Fuzzy K-Means operating on Dirichlet sample data
Added: lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/fuzzykmeans/ lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/fuzzykmeans/DisplayFuzzyKMeans.java lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/meanshift/ lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/meanshift/DisplayMeanShift.java Modified: lucene/mahout/trunk/core/ (props changed) lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/SoftCluster.java lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopy.java lucene/mahout/trunk/examples/ (props changed) lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/dirichlet/DisplayDirichlet.java Propchange: lucene/mahout/trunk/core/ ------------------------------------------------------------------------------ --- svn:ignore (original) +++ svn:ignore Wed Apr 15 22:38:44 2009 @@ -6,3 +6,6 @@ test target *.iml +.settings +.classpath +.project Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/SoftCluster.java URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/SoftCluster.java?rev=765403&r1=765402&r2=765403&view=diff ============================================================================== --- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/SoftCluster.java (original) +++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/fuzzykmeans/SoftCluster.java Wed Apr 15 22:38:44 2009 @@ -26,6 +26,7 @@ import org.apache.hadoop.mapred.OutputCollector; import org.apache.mahout.matrix.AbstractVector; import org.apache.mahout.matrix.SparseVector; +import org.apache.mahout.matrix.SquareRootFunction; import org.apache.mahout.matrix.Vector; import org.apache.mahout.utils.DistanceMeasure; @@ -42,7 +43,8 @@ private static double m = 2.0; // default value public static final double MINIMAL_VALUE = 0.0000000001; // using it for - // adding + + // adding // exception // this value to any @@ -70,6 +72,13 @@ // has the centroid converged with the center? private boolean converged = false; + // track membership parameters + double s0 = 0; + + Vector s1; + + Vector s2; + private static DistanceMeasure measure; private static double convergenceDelta = 0; @@ -163,9 +172,9 @@ double probWeight = computeProbWeight(clusterDistanceList.get(i), clusterDistanceList); Text key = new Text(clusters.get(i).getIdentifier()); // just output the - // identifier,avoids - // too much data - // traffic + // identifier,avoids + // too much data + // traffic Text value = new Text(Double.toString(probWeight) + FuzzyKMeansDriver.MAPPER_VALUE_SEPARATOR + values.toString()); output.collect(key, value); @@ -203,8 +212,7 @@ probWeight).append(' '); } output.collect(new Text(outputKey.trim()), new Text(outputValue.toString() - .trim() - + ']')); + .trim() + ']')); } /** @@ -295,12 +303,47 @@ } /** + * Observe the point, accumulating weighted variables for std() calculation + * @param point + * @param ptProb + */ + private void observePoint(Vector point, double ptProb) { + s0 += ptProb; + Vector wtPt = point.times(ptProb); + if (s1 == null) + s1 = point.copy(); + else + s1 = s1.plus(wtPt); + if (s2 == null) + s2 = wtPt.times(wtPt); + else + s2 = s2.plus(wtPt.times(wtPt)); + } + + /** + * Compute a "standard deviation" value to use as the "radius" of the cluster for display purposes + * @return + */ + public double std() { + if (s0 > 0) { + Vector radical = s2.times(s0).minus(s1.times(s1)); + radical = radical.times(radical).assign(new SquareRootFunction()); + Vector stds = radical.assign(new SquareRootFunction()).divide(s0); + double res = stds.zSum() / stds.cardinality(); + System.out.println(res); + return res; + } else + return 0.33; + } + + /** * Add the point to the SoftCluster * * @param point a point to add * @param ptProb */ public void addPoint(Vector point, double ptProb) { + observePoint(point, ptProb); centroid = null; pointProbSum += ptProb; if (weightedPointTotal == null) Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopy.java URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopy.java?rev=765403&r1=765402&r2=765403&view=diff ============================================================================== --- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopy.java (original) +++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/meanshift/MeanShiftCanopy.java Wed Apr 15 22:38:44 2009 @@ -57,10 +57,10 @@ private static int nextCanopyId = 0; // the T1 distance threshold - private static double t1; + static double t1; // the T2 distance threshold - private static double t2; + static double t2; // the distance measure private static DistanceMeasure measure; Propchange: lucene/mahout/trunk/examples/ ------------------------------------------------------------------------------ --- svn:ignore (original) +++ svn:ignore Wed Apr 15 22:38:44 2009 @@ -6,3 +6,6 @@ temp work *.iml +.settings +.classpath +.project Modified: lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/dirichlet/DisplayDirichlet.java URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/dirichlet/DisplayDirichlet.java?rev=765403&r1=765402&r2=765403&view=diff ============================================================================== --- lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/dirichlet/DisplayDirichlet.java (original) +++ lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/dirichlet/DisplayDirichlet.java Wed Apr 15 22:38:44 2009 @@ -22,11 +22,11 @@ public class DisplayDirichlet extends Frame { private static final long serialVersionUID = 1L; - int res; //screen resolution + protected int res; //screen resolution - int ds = 72; //default scale = 72 pixels per inch + protected int ds = 72; //default scale = 72 pixels per inch - int size = 8; // screen size in inches + protected int size = 8; // screen size in inches public static List<Vector> sampleData = new ArrayList<Vector>(); Added: lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/fuzzykmeans/DisplayFuzzyKMeans.java URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/fuzzykmeans/DisplayFuzzyKMeans.java?rev=765403&view=auto ============================================================================== --- lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/fuzzykmeans/DisplayFuzzyKMeans.java (added) +++ lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/fuzzykmeans/DisplayFuzzyKMeans.java Wed Apr 15 22:38:44 2009 @@ -0,0 +1,188 @@ +package org.apache.mahout.clustering.fuzzykmeans; + +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import java.awt.BasicStroke; +import java.awt.Graphics; +import java.awt.Graphics2D; +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; + +import org.apache.mahout.clustering.canopy.Canopy; +import org.apache.mahout.clustering.dirichlet.DisplayDirichlet; +import org.apache.mahout.clustering.dirichlet.UncommonDistributions; +import org.apache.mahout.clustering.kmeans.Cluster; +import org.apache.mahout.matrix.DenseVector; +import org.apache.mahout.matrix.Vector; +import org.apache.mahout.utils.DistanceMeasure; +import org.apache.mahout.utils.ManhattanDistanceMeasure; + +class DisplayFuzzyKMeans extends DisplayDirichlet { + public DisplayFuzzyKMeans() { + initialize(); + this.setTitle("Fuzzy K-Means Clusters (> 5% of population)"); + } + + private static final long serialVersionUID = 1L; + + static List<Canopy> canopies; + + static List<List<SoftCluster>> clusters; + + static double t1 = 3.0; + + static double t2 = 1.5; + + public void paint(Graphics g) { + super.plotSampleData(g); + Graphics2D g2 = (Graphics2D) g; + Vector dv = new DenseVector(2); + int i = clusters.size() - 1; + for (List<SoftCluster> cls : clusters) { + g2.setStroke(new BasicStroke(i == 0 ? 3 : 1)); + g2.setColor(colors[Math.min(colors.length - 1, i--)]); + for (SoftCluster cluster : cls) + if (true || cluster.getWeightedPointTotal().zSum() > sampleData.size() * 0.05) { + dv.assign(cluster.std() * 3); + plotEllipse(g2, cluster.getCenter(), dv); + } + } + } + + public static void referenceFuzzyKMeans(List<Vector> points, + DistanceMeasure measure, double threshold, int numIter) throws Exception { + SoftCluster.config(measure, threshold); + boolean converged = false; + int iteration = 0; + for (int iter = 0; !converged && iter < numIter; iter++) { + List<SoftCluster> next = new ArrayList<SoftCluster>(); + List<SoftCluster> cs = clusters.get(iteration++); + for (SoftCluster c : cs) + next.add(new SoftCluster(c.getCenter())); + clusters.add(next); + converged = iterateReference(points, clusters.get(iteration), measure); + } + } + + /** + * Perform a single iteration over the points and clusters, assigning points + * to clusters and returning if the iterations are completed. + * + * @param points the List<Vector> having the input points + * @param clusters the List<Cluster> clusters + * @param measure a DistanceMeasure to use + * @return + */ + public static boolean iterateReference(List<Vector> points, + List<SoftCluster> clusterList, DistanceMeasure measure) { + // for each + for (Vector point : points) { + List<Double> clusterDistanceList = new ArrayList<Double>(); + for (SoftCluster cluster : clusterList) { + clusterDistanceList.add(measure.distance(point, cluster.getCenter())); + } + + for (int i = 0; i < clusterList.size(); i++) { + double probWeight = SoftCluster.computeProbWeight(clusterDistanceList + .get(i), clusterDistanceList); + clusterList.get(i).addPoint(point, + Math.pow(probWeight, SoftCluster.getM())); + } + } + boolean converged = true; + for (SoftCluster cluster : clusterList) { + if (!cluster.computeConvergence()) + converged = false; + } + // update the cluster centers + if (!converged) + for (SoftCluster cluster : clusterList) + cluster.recomputeCenter(); + return converged; + + } + + /** + * Iterate through the points, adding new canopies. Return the canopies. + * + * @param measure + * a DistanceMeasure to use + * @param points + * a list<Vector> defining the points to be clustered + * @param t1 + * the T1 distance threshold + * @param t2 + * the T2 distance threshold + * @return the List<Canopy> created + */ + static List<Canopy> populateCanopies(DistanceMeasure measure, + List<Vector> points, double t1, double t2) { + List<Canopy> canopies = new ArrayList<Canopy>(); + Canopy.config(measure, t1, t2); + /** + * Reference Implementation: Given a distance metric, one can create + * canopies as follows: Start with a list of the data points in any order, + * and with two distance thresholds, T1 and T2, where T1 > T2. (These + * thresholds can be set by the user, or selected by cross-validation.) Pick + * a point on the list and measure its distance to all other points. Put all + * points that are within distance threshold T1 into a canopy. Remove from + * the list all points that are within distance threshold T2. Repeat until + * the list is empty. + */ + while (!points.isEmpty()) { + Iterator<Vector> ptIter = points.iterator(); + Vector p1 = ptIter.next(); + ptIter.remove(); + Canopy canopy = new Canopy(p1); + canopies.add(canopy); + while (ptIter.hasNext()) { + Vector p2 = ptIter.next(); + double dist = measure.distance(p1, p2); + // Put all points that are within distance threshold T1 into the canopy + if (dist < t1) + canopy.addPoint(p2); + // Remove from the list all points that are within distance threshold T2 + if (dist < t2) + ptIter.remove(); + } + } + return canopies; + } + + public static void main(String[] args) { + UncommonDistributions.init("Mahout=Hadoop+ML".getBytes()); + generateSamples(); + List<Vector> points = new ArrayList<Vector>(); + points.addAll(sampleData); + canopies = populateCanopies(new ManhattanDistanceMeasure(), points, t1, t2); + DistanceMeasure measure = new ManhattanDistanceMeasure(); + Cluster.config(measure, 0.001); + clusters = new ArrayList<List<SoftCluster>>(); + clusters.add(new ArrayList<SoftCluster>()); + for (Canopy canopy : canopies) + if (canopy.getNumPoints() > 0.05 * sampleData.size()) + clusters.get(0).add(new SoftCluster(canopy.getCenter())); + try { + referenceFuzzyKMeans(sampleData, measure, 0.001, 10); + } catch (Exception e) { + e.printStackTrace(); + } + new DisplayFuzzyKMeans(); + } +} Added: lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/meanshift/DisplayMeanShift.java URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/meanshift/DisplayMeanShift.java?rev=765403&view=auto ============================================================================== --- lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/meanshift/DisplayMeanShift.java (added) +++ lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/meanshift/DisplayMeanShift.java Wed Apr 15 22:38:44 2009 @@ -0,0 +1,108 @@ +package org.apache.mahout.clustering.meanshift; + +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import java.awt.Color; +import java.awt.Graphics; +import java.awt.Graphics2D; +import java.awt.geom.AffineTransform; +import java.util.ArrayList; +import java.util.List; + +import org.apache.mahout.clustering.dirichlet.DisplayDirichlet; +import org.apache.mahout.clustering.dirichlet.UncommonDistributions; +import org.apache.mahout.clustering.dirichlet.models.NormalModelDistribution; +import org.apache.mahout.matrix.DenseVector; +import org.apache.mahout.matrix.Vector; +import org.apache.mahout.utils.EuclideanDistanceMeasure; + +class DisplayMeanShift extends DisplayDirichlet { + public DisplayMeanShift() { + initialize(); + this.setTitle("Canopy Clusters (> 1.5% of population)"); + } + + private static final long serialVersionUID = 1L; + + private static List<MeanShiftCanopy> canopies = new ArrayList<MeanShiftCanopy>(); + + private static List<List<Vector>> iterationCenters = new ArrayList<List<Vector>>(); + + public void paint(Graphics g) { + Graphics2D g2 = (Graphics2D) g; + double sx = (double) res / ds; + g2.setTransform(AffineTransform.getScaleInstance(sx, sx)); + + // plot the axes + g2.setColor(Color.BLACK); + Vector dv = new DenseVector(2).assign(size / 2); + Vector dv1 = new DenseVector(2).assign(MeanShiftCanopy.t1); + Vector dv2 = new DenseVector(2).assign(MeanShiftCanopy.t2); + plotRectangle(g2, new DenseVector(2).assign(2), dv); + plotRectangle(g2, new DenseVector(2).assign(-2), dv); + + // plot the sample data + g2.setColor(Color.DARK_GRAY); + dv.assign(0.03); + for (Vector v : sampleData) + plotRectangle(g2, v, dv); + int i = 0; + for (MeanShiftCanopy canopy : canopies) + if (canopy.getBoundPoints().size() > 0.015 * sampleData.size()) { + g2.setColor(colors[Math.min(i++, colors.length - 1)]); + for (Vector v : canopy.getBoundPoints()) + plotRectangle(g2, v, dv); + plotEllipse(g2, canopy.getCenter(), dv1); + plotEllipse(g2, canopy.getCenter(), dv2); + } + } + + public static void testReferenceImplementation() { + MeanShiftCanopy.config(new EuclideanDistanceMeasure(), 1.0, 0.05, 0.5); + // add all points to the canopies + for (Vector aRaw : sampleData) { + MeanShiftCanopy.mergeCanopy(new MeanShiftCanopy(aRaw), canopies); + } + boolean done = false; + while (!done) {// shift canopies to their centroids + done = true; + List<MeanShiftCanopy> migratedCanopies = new ArrayList<MeanShiftCanopy>(); + List<Vector> centers = new ArrayList<Vector>(); + for (MeanShiftCanopy canopy : canopies) { + centers.add(canopy.getCenter()); + done = canopy.shiftToMean() && done; + MeanShiftCanopy.mergeCanopy(canopy, migratedCanopies); + } + iterationCenters.add(centers); + canopies = migratedCanopies; + } + } + + public static void main(String[] args) { + UncommonDistributions.init("Mahout=Hadoop+ML".getBytes()); + generateSamples(); + testReferenceImplementation(); + for (MeanShiftCanopy canopy : canopies) + System.out.println(canopy.toString()); + new DisplayMeanShift(); + } + + static void generateResults() { + DisplayDirichlet.generateResults(new NormalModelDistribution()); + } +}