Modified: mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayClustering.java URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayClustering.java?rev=1336424&r1=1336423&r2=1336424&view=diff ============================================================================== --- mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayClustering.java (original) +++ mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayClustering.java Wed May 9 22:02:50 2012 @@ -37,7 +37,6 @@ import org.apache.hadoop.conf.Configurat import org.apache.hadoop.fs.FileStatus; import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; -import org.apache.hadoop.fs.PathFilter; import org.apache.hadoop.io.IntWritable; import org.apache.hadoop.io.SequenceFile; import org.apache.hadoop.io.Text; @@ -329,20 +328,6 @@ public class DisplayClustering extends F } } - protected static List<Cluster> readClusters(Path clustersIn) { - List<Cluster> clusters = Lists.newArrayList(); - Configuration conf = new Configuration(); - for (Cluster value : new SequenceFileDirValueIterable<Cluster>(clustersIn, PathType.LIST, - PathFilters.logsCRCFilter(), conf)) { - log.info( - "Reading Cluster:{} center:{} numPoints:{} radius:{}", - new Object[] {value.getId(), AbstractCluster.formatVector(value.getCenter(), null), - value.getNumObservations(), AbstractCluster.formatVector(value.getRadius(), null)}); - clusters.add(value); - } - return clusters; - } - protected static List<Cluster> readClustersWritable(Path clustersIn) { List<Cluster> clusters = Lists.newArrayList(); Configuration conf = new Configuration(); @@ -358,15 +343,6 @@ public class DisplayClustering extends F return clusters; } - protected static void loadClusters(Path output) throws IOException { - Configuration conf = new Configuration(); - FileSystem fs = FileSystem.get(output.toUri(), conf); - for (FileStatus s : fs.listStatus(output, new ClustersFilter())) { - List<Cluster> clusters = readClusters(s.getPath()); - CLUSTERS.add(clusters); - } - } - protected static void loadClustersWritable(Path output) throws IOException { Configuration conf = new Configuration(); FileSystem fs = FileSystem.get(output.toUri(), conf); @@ -376,15 +352,6 @@ public class DisplayClustering extends F } } - protected static void loadClusters(Path output, PathFilter filter) throws IOException { - Configuration conf = new Configuration(); - FileSystem fs = FileSystem.get(output.toUri(), conf); - for (FileStatus s : fs.listStatus(output, filter)) { - List<Cluster> clusters = readClusters(s.getPath()); - CLUSTERS.add(clusters); - } - } - /** * Generate random samples and add them to the sampleData *
Modified: mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayDirichlet.java URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayDirichlet.java?rev=1336424&r1=1336423&r2=1336424&view=diff ============================================================================== --- mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayDirichlet.java (original) +++ mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayDirichlet.java Wed May 9 22:02:50 2012 @@ -22,28 +22,28 @@ import java.awt.Graphics2D; import java.io.IOException; import java.util.List; +import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.Path; import org.apache.mahout.clustering.Cluster; import org.apache.mahout.clustering.Model; import org.apache.mahout.clustering.ModelDistribution; import org.apache.mahout.clustering.classify.ClusterClassifier; -import org.apache.mahout.clustering.dirichlet.DirichletClusterer; +import org.apache.mahout.clustering.dirichlet.DirichletDriver; +import org.apache.mahout.clustering.dirichlet.models.DistributionDescription; import org.apache.mahout.clustering.dirichlet.models.GaussianClusterDistribution; import org.apache.mahout.clustering.iterator.ClusterIterator; -import org.apache.mahout.clustering.iterator.ClusteringPolicy; import org.apache.mahout.clustering.iterator.DirichletClusteringPolicy; +import org.apache.mahout.common.HadoopUtil; import org.apache.mahout.common.RandomUtils; +import org.apache.mahout.common.distance.ManhattanDistanceMeasure; import org.apache.mahout.math.DenseVector; +import org.apache.mahout.math.RandomAccessSparseVector; import org.apache.mahout.math.VectorWritable; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; import com.google.common.collect.Lists; public class DisplayDirichlet extends DisplayClustering { - private static final Logger log = LoggerFactory.getLogger(DisplayDirichlet.class); - public DisplayDirichlet() { initialize(); this.setTitle("Dirichlet Process Clusters - Normal Distribution (>" + (int) (significance * 100) @@ -57,50 +57,19 @@ public class DisplayDirichlet extends Di plotClusters((Graphics2D) g); } - protected static void printModels(Iterable<Cluster[]> result, int significant) { - int row = 0; - StringBuilder models = new StringBuilder(100); - for (Cluster[] r : result) { - models.append("sample[").append(row++).append("]= "); - for (int k = 0; k < r.length; k++) { - Cluster model = r[k]; - if (model.getNumObservations() > significant) { - models.append('m').append(k).append(model.asFormatString(null)).append(", "); - } - } - models.append('\n'); - } - models.append('\n'); - log.info(models.toString()); - } - - protected static void generateResults(ModelDistribution<VectorWritable> modelDist, int numClusters, - int numIterations, double alpha0, int thin, int burnin) throws IOException { - boolean runClusterer = false; + protected static void generateResults(Path input, Path output, + ModelDistribution<VectorWritable> modelDist, int numClusters, int numIterations, double alpha0, int thin, int burnin) throws IOException, ClassNotFoundException, + InterruptedException { + boolean runClusterer = true; if (runClusterer) { - runSequentialDirichletClusterer(modelDist, numClusters, numIterations, alpha0, thin, burnin); + runSequentialDirichletClusterer(input, output, modelDist, numClusters, numIterations, alpha0); } else { - runSequentialDirichletClassifier(modelDist, numClusters, numIterations, alpha0); + runSequentialDirichletClassifier(input, output, modelDist, numClusters, numIterations, alpha0); } - } - - private static void runSequentialDirichletClassifier(ModelDistribution<VectorWritable> modelDist, int numClusters, - int numIterations, double alpha0) throws IOException { - List<Cluster> models = Lists.newArrayList(); - for (Model<VectorWritable> cluster : modelDist.sampleFromPrior(numClusters)) { - models.add((Cluster) cluster); - } - ClusterClassifier prior = new ClusterClassifier(models, new DirichletClusteringPolicy(numClusters, alpha0)); - Path samples = new Path("samples"); - Path output = new Path("output"); - Path priorPath = new Path(output, "clusters-0"); - prior.writeToSeqFiles(priorPath); - - new ClusterIterator().iterateSeq(samples, priorPath, output, numIterations); for (int i = 1; i <= numIterations; i++) { ClusterClassifier posterior = new ClusterClassifier(); String name = i == numIterations ? "clusters-" + i + "-final" : "clusters-" + i; - posterior.readFromSeqFiles(new Path(output, name)); + posterior.readFromSeqFiles(new Configuration(), new Path(output, name)); List<Cluster> clusters = Lists.newArrayList(); for (Cluster cluster : posterior.getModels()) { if (isSignificant(cluster)) { @@ -111,33 +80,47 @@ public class DisplayDirichlet extends Di } } - private static void runSequentialDirichletClusterer(ModelDistribution<VectorWritable> modelDist, int numClusters, - int numIterations, double alpha0, int thin, int burnin) { - DirichletClusterer dc = new DirichletClusterer(SAMPLE_DATA, modelDist, alpha0, numClusters, thin, burnin); - List<Cluster[]> result = dc.cluster(numIterations); - printModels(result, burnin); - for (Cluster[] models : result) { - List<Cluster> clusters = Lists.newArrayList(); - for (Cluster cluster : models) { - if (isSignificant(cluster)) { - clusters.add(cluster); - } - } - CLUSTERS.add(clusters); + private static void runSequentialDirichletClassifier(Path input, Path output, + ModelDistribution<VectorWritable> modelDist, int numClusters, int numIterations, double alpha0) + throws IOException { + List<Cluster> models = Lists.newArrayList(); + for (Model<VectorWritable> cluster : modelDist.sampleFromPrior(numClusters)) { + models.add((Cluster) cluster); } + ClusterClassifier prior = new ClusterClassifier(models, new DirichletClusteringPolicy(numClusters, alpha0)); + Path priorPath = new Path(output, Cluster.INITIAL_CLUSTERS_DIR); + prior.writeToSeqFiles(priorPath); + Configuration conf = new Configuration(); + new ClusterIterator().iterateSeq(conf, input, priorPath, output, numIterations); + } + + private static void runSequentialDirichletClusterer(Path input, Path output, + ModelDistribution<VectorWritable> modelDist, int numClusters, int numIterations, double alpha0) + throws IOException, ClassNotFoundException, InterruptedException { + DistributionDescription description = new DistributionDescription(modelDist.getClass().getName(), + RandomAccessSparseVector.class.getName(), ManhattanDistanceMeasure.class.getName(), 2); + + DirichletDriver.run(new Configuration(), input, output, description, numClusters, numIterations, alpha0, true, + true, 0, false); } public static void main(String[] args) throws Exception { VectorWritable modelPrototype = new VectorWritable(new DenseVector(2)); ModelDistribution<VectorWritable> modelDist = new GaussianClusterDistribution(modelPrototype); + Configuration conf = new Configuration(); + Path output = new Path("output"); + HadoopUtil.delete(conf, output); + Path samples = new Path("samples"); + HadoopUtil.delete(conf, samples); RandomUtils.useTestSeed(); generateSamples(); + writeSampleData(samples); int numIterations = 20; int numClusters = 10; int alpha0 = 1; int thin = 3; int burnin = 5; - generateResults(modelDist, numClusters, numIterations, alpha0, thin, burnin); + generateResults(samples, output, modelDist, numClusters, numIterations, alpha0, thin, burnin); new DisplayDirichlet(); } Modified: mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayFuzzyKMeans.java URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayFuzzyKMeans.java?rev=1336424&r1=1336423&r2=1336424&view=diff ============================================================================== --- mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayFuzzyKMeans.java (original) +++ mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayFuzzyKMeans.java Wed May 9 22:02:50 2012 @@ -60,12 +60,12 @@ public class DisplayFuzzyKMeans extends Path samples = new Path("samples"); Path output = new Path("output"); Configuration conf = new Configuration(); - HadoopUtil.delete(conf, samples); HadoopUtil.delete(conf, output); + HadoopUtil.delete(conf, samples); RandomUtils.useTestSeed(); DisplayClustering.generateSamples(); writeSampleData(samples); - boolean runClusterer = false; + boolean runClusterer = true; int maxIterations = 10; float threshold = 0.001F; float m = 1.1F; @@ -93,16 +93,17 @@ public class DisplayFuzzyKMeans extends Path priorPath = new Path(output, "classifier-0"); prior.writeToSeqFiles(priorPath); - new ClusterIterator().iterateSeq(samples, priorPath, output, maxIterations); + new ClusterIterator().iterateSeq(conf, samples, priorPath, output, maxIterations); loadClustersWritable(output); } private static void runSequentialFuzzyKClusterer(Configuration conf, Path samples, Path output, DistanceMeasure measure, int maxIterations, float m, double threshold) throws IOException, ClassNotFoundException, InterruptedException { - Path clusters = RandomSeedGenerator.buildRandom(conf, samples, new Path(output, "clusters-0"), 3, measure); - FuzzyKMeansDriver.run(samples, clusters, output, measure, threshold, maxIterations, m, true, true, threshold, true); + Path clustersIn = new Path(output, "random-seeds"); + RandomSeedGenerator.buildRandom(conf, samples, clustersIn, 3, measure); + FuzzyKMeansDriver.run(samples, clustersIn, output, measure, threshold, maxIterations, m, true, true, threshold, true); - loadClusters(output); + loadClustersWritable(output); } } Modified: mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayKMeans.java URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayKMeans.java?rev=1336424&r1=1336423&r2=1336424&view=diff ============================================================================== --- mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayKMeans.java (original) +++ mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayKMeans.java Wed May 9 22:02:50 2012 @@ -55,9 +55,9 @@ public class DisplayKMeans extends Displ HadoopUtil.delete(conf, output); RandomUtils.useTestSeed(); - DisplayClustering.generateSamples(); + generateSamples(); writeSampleData(samples); - boolean runClusterer = false; + boolean runClusterer = true; double convergenceDelta = 0.001; if (runClusterer) { int numClusters = 3; @@ -81,20 +81,21 @@ public class DisplayKMeans extends Displ initialClusters.add(new org.apache.mahout.clustering.kmeans.Kluster(point, id++, measure)); } ClusterClassifier prior = new ClusterClassifier(initialClusters, new KMeansClusteringPolicy(convergenceDelta)); - Path priorPath = new Path(output, "clusters-0"); + Path priorPath = new Path(output, Cluster.INITIAL_CLUSTERS_DIR); prior.writeToSeqFiles(priorPath); int maxIter = 10; - new ClusterIterator().iterateSeq(samples, priorPath, output, maxIter); + new ClusterIterator().iterateSeq(conf, samples, priorPath, output, maxIter); loadClustersWritable(output); } private static void runSequentialKMeansClusterer(Configuration conf, Path samples, Path output, DistanceMeasure measure, int maxIterations, double convergenceDelta) throws IOException, InterruptedException, ClassNotFoundException { - Path clusters = RandomSeedGenerator.buildRandom(conf, samples, new Path(output, "clusters-0"), 3, measure); - KMeansDriver.run(samples, clusters, output, measure, convergenceDelta, maxIterations, true, 0.0, true); - loadClusters(output); + Path clustersIn = new Path(output, "random-seeds"); + RandomSeedGenerator.buildRandom(conf, samples, clustersIn, 3, measure); + KMeansDriver.run(samples, clustersIn, output, measure, convergenceDelta, maxIterations, true, 0.0, true); + loadClustersWritable(output); } // Override the paint() method Modified: mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayMeanShift.java URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayMeanShift.java?rev=1336424&r1=1336423&r2=1336424&view=diff ============================================================================== --- mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayMeanShift.java (original) +++ mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayMeanShift.java Wed May 9 22:02:50 2012 @@ -111,7 +111,7 @@ public class DisplayMeanShift extends Di // if (b) { MeanShiftCanopyDriver.run(conf, samples, output, measure, kernelProfile, t1, t2, 0.005, 20, false, true, true); - loadClusters(output); + loadClustersWritable(output); // } else { // Collection<Vector> points = new ArrayList<Vector>(); // for (VectorWritable sample : SAMPLE_DATA) { Modified: mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/dirichlet/Job.java URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/dirichlet/Job.java?rev=1336424&r1=1336423&r2=1336424&view=diff ============================================================================== --- mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/dirichlet/Job.java (original) +++ mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/dirichlet/Job.java Wed May 9 22:02:50 2012 @@ -139,16 +139,8 @@ public final class Job extends AbstractJ throws Exception{ Path directoryContainingConvertedInput = new Path(output, DIRECTORY_CONTAINING_CONVERTED_INPUT); InputDriver.runJob(input, directoryContainingConvertedInput, "org.apache.mahout.math.RandomAccessSparseVector"); - DirichletDriver.run(directoryContainingConvertedInput, - output, - description, - numModels, - maxIterations, - alpha0, - true, - emitMostLikely, - threshold, - false); + DirichletDriver.run(new Configuration(), directoryContainingConvertedInput, output, description, numModels, maxIterations, alpha0, true, + emitMostLikely, threshold, false); // run ClusterDumper ClusterDumper clusterDumper = new ClusterDumper(new Path(output, "clusters-" + maxIterations), new Path(output, "clusteredPoints")); Modified: mahout/trunk/integration/src/test/java/org/apache/mahout/clustering/TestClusterEvaluator.java URL: http://svn.apache.org/viewvc/mahout/trunk/integration/src/test/java/org/apache/mahout/clustering/TestClusterEvaluator.java?rev=1336424&r1=1336423&r2=1336424&view=diff ============================================================================== --- mahout/trunk/integration/src/test/java/org/apache/mahout/clustering/TestClusterEvaluator.java (original) +++ mahout/trunk/integration/src/test/java/org/apache/mahout/clustering/TestClusterEvaluator.java Wed May 9 22:02:50 2012 @@ -412,8 +412,8 @@ public final class TestClusterEvaluator DistributionDescription description = new DistributionDescription( GaussianClusterDistribution.class.getName(), DenseVector.class.getName(), null, 2); - DirichletDriver.run(testdata, output, description, 15, 5, 1.0, true, true, - 0, true); + DirichletDriver.run(new Configuration(), testdata, output, description, 15, 5, 1.0, true, + true, (double) 0, true); int numIterations = 10; Configuration conf = new Configuration(); Path clustersIn = new Path(output, "clusters-5-final"); Modified: mahout/trunk/integration/src/test/java/org/apache/mahout/clustering/cdbw/TestCDbwEvaluator.java URL: http://svn.apache.org/viewvc/mahout/trunk/integration/src/test/java/org/apache/mahout/clustering/cdbw/TestCDbwEvaluator.java?rev=1336424&r1=1336423&r2=1336424&view=diff ============================================================================== --- mahout/trunk/integration/src/test/java/org/apache/mahout/clustering/cdbw/TestCDbwEvaluator.java (original) +++ mahout/trunk/integration/src/test/java/org/apache/mahout/clustering/cdbw/TestCDbwEvaluator.java Wed May 9 22:02:50 2012 @@ -427,8 +427,8 @@ public final class TestCDbwEvaluator ext DistributionDescription description = new DistributionDescription( GaussianClusterDistribution.class.getName(), DenseVector.class.getName(), null, 2); - DirichletDriver.run(testdata, output, description, 15, 5, 1.0, true, true, - 0, true); + DirichletDriver.run(new Configuration(), testdata, output, description, 15, 5, 1.0, true, + true, (double) 0, true); int numIterations = 10; Path clustersIn = new Path(output, "clusters-0"); RepresentativePointsDriver.run(conf, clustersIn, new Path(output, Modified: mahout/trunk/integration/src/test/java/org/apache/mahout/clustering/dirichlet/TestL1ModelClustering.java URL: http://svn.apache.org/viewvc/mahout/trunk/integration/src/test/java/org/apache/mahout/clustering/dirichlet/TestL1ModelClustering.java?rev=1336424&r1=1336423&r2=1336424&view=diff ============================================================================== --- mahout/trunk/integration/src/test/java/org/apache/mahout/clustering/dirichlet/TestL1ModelClustering.java (original) +++ mahout/trunk/integration/src/test/java/org/apache/mahout/clustering/dirichlet/TestL1ModelClustering.java Wed May 9 22:02:50 2012 @@ -23,8 +23,7 @@ import java.util.Iterator; import java.util.List; import java.util.Locale; -import com.google.common.collect.Lists; -import com.google.common.io.Closeables; +import org.apache.hadoop.conf.Configuration; import org.apache.lucene.analysis.standard.StandardAnalyzer; import org.apache.lucene.document.Document; import org.apache.lucene.document.Field; @@ -35,8 +34,16 @@ import org.apache.lucene.store.RAMDirect import org.apache.lucene.util.Version; import org.apache.mahout.clustering.Cluster; import org.apache.mahout.clustering.Model; +import org.apache.mahout.clustering.ModelDistribution; +import org.apache.mahout.clustering.classify.ClusterClassifier; import org.apache.mahout.clustering.dirichlet.models.DistanceMeasureClusterDistribution; +import org.apache.mahout.clustering.dirichlet.models.DistributionDescription; import org.apache.mahout.clustering.dirichlet.models.GaussianClusterDistribution; +import org.apache.mahout.clustering.iterator.ClusterIterator; +import org.apache.mahout.clustering.iterator.DirichletClusteringPolicy; +import org.apache.mahout.common.distance.CosineDistanceMeasure; +import org.apache.mahout.common.distance.ManhattanDistanceMeasure; +import org.apache.mahout.math.RandomAccessSparseVector; import org.apache.mahout.math.Vector; import org.apache.mahout.math.VectorWritable; import org.apache.mahout.utils.MahoutTestCase; @@ -49,6 +56,9 @@ import org.apache.mahout.vectorizer.TFID import org.apache.mahout.vectorizer.Weight; import org.junit.Test; +import com.google.common.collect.Lists; +import com.google.common.io.Closeables; + public final class TestL1ModelClustering extends MahoutTestCase { private class MapElement implements Comparable<MapElement> { @@ -118,7 +128,7 @@ public final class TestL1ModelClustering "The robber wore a white fleece jacket and a baseball cap.", "The English Springer Spaniel is the best of all dogs."}; - private List<VectorWritable> sampleData; + private List<Vector> sampleData; private void getSampleData(String[] docs2) throws IOException { sampleData = Lists.newArrayList(); @@ -148,11 +158,11 @@ public final class TestL1ModelClustering for (Vector vector : iterable) { assertNotNull(vector); System.out.println("Vector[" + i++ + "]=" + formatVector(vector)); - sampleData.add(new VectorWritable(vector)); + sampleData.add(vector); } } - private static String formatVector(Vector v) { + private String formatVector(Vector v) { StringBuilder buf = new StringBuilder(); int nzero = 0; Iterator<Vector.Element> iterateNonZero = v.iterateNonZero(); @@ -179,7 +189,7 @@ public final class TestL1ModelClustering return buf.toString(); } - private static void printSamples(Iterable<Cluster[]> result, int significant) { + private void printSamples(Iterable<Cluster[]> result, int significant) { int row = 0; for (Cluster[] r : result) { int sig = 0; @@ -199,19 +209,19 @@ public final class TestL1ModelClustering System.out.println(); } - private void printClusters(Model<VectorWritable>[] models, List<VectorWritable> samples, String[] docs) { - for (int m = 0; m < models.length; m++) { - Model<VectorWritable> model = models[m]; + private void printClusters(List<Cluster> models, String[] docs) { + for (int m = 0; m < models.size(); m++) { + Cluster model = models.get(m); long count = model.getNumObservations(); if (count == 0) { continue; } - System.out.println("Model[" + m + "] had " + count + " hits (!) and " + (samples.size() - count) + System.out.println("Model[" + m + "] had " + count + " hits (!) and " + (sampleData.size() - count) + " misses (? in pdf order) during the last iteration:"); - MapElement[] map = new MapElement[samples.size()]; + MapElement[] map = new MapElement[sampleData.size()]; // sort the samples by pdf - for (int i = 0; i < samples.size(); i++) { - VectorWritable sample = samples.get(i); + for (int i = 0; i < sampleData.size(); i++) { + VectorWritable sample = new VectorWritable(sampleData.get(i)); map[i] = new MapElement(model.pdf(sample), docs[i]); } Arrays.sort(map); @@ -230,45 +240,81 @@ public final class TestL1ModelClustering @Test public void testDocs() throws Exception { getSampleData(DOCS); - DirichletClusterer dc = new DirichletClusterer(sampleData, new GaussianClusterDistribution(sampleData.get(0)), 1.0, - 15, 1, 0); - List<Cluster[]> result = dc.cluster(10); - assertNotNull(result); - printSamples(result, 0); - printClusters(result.get(result.size() - 1), sampleData, DOCS); + DistributionDescription description = new DistributionDescription(GaussianClusterDistribution.class.getName(), + RandomAccessSparseVector.class.getName(), ManhattanDistanceMeasure.class.getName(), sampleData.get(0).size()); + + List<Cluster> models = Lists.newArrayList(); + ModelDistribution<VectorWritable> modelDist = description.createModelDistribution(new Configuration()); + for (Model<VectorWritable> cluster : modelDist.sampleFromPrior(15)) { + models.add((Cluster) cluster); + } + + ClusterIterator iterator = new ClusterIterator(); + ClusterClassifier classifier = new ClusterClassifier(models, new DirichletClusteringPolicy(15, 1.0)); + ClusterClassifier posterior = iterator.iterate(sampleData, classifier, 10); + + printClusters(posterior.getModels(), DOCS); } @Test public void testDMDocs() throws Exception { + getSampleData(DOCS); - DirichletClusterer dc = new DirichletClusterer(sampleData, - new DistanceMeasureClusterDistribution(sampleData.get(0)), 1.0, 15, 1, 0); - List<Cluster[]> result = dc.cluster(10); - assertNotNull(result); - printSamples(result, 0); - printClusters(result.get(result.size() - 1), sampleData, DOCS); + DistributionDescription description = new DistributionDescription( + DistanceMeasureClusterDistribution.class.getName(), RandomAccessSparseVector.class.getName(), + CosineDistanceMeasure.class.getName(), sampleData.get(0).size()); + + List<Cluster> models = Lists.newArrayList(); + ModelDistribution<VectorWritable> modelDist = description.createModelDistribution(new Configuration()); + for (Model<VectorWritable> cluster : modelDist.sampleFromPrior(15)) { + models.add((Cluster) cluster); + } + + ClusterIterator iterator = new ClusterIterator(); + ClusterClassifier classifier = new ClusterClassifier(models, new DirichletClusteringPolicy(15, 1.0)); + ClusterClassifier posterior = iterator.iterate(sampleData, classifier, 10); + + printClusters(posterior.getModels(), DOCS); } @Test public void testDocs2() throws Exception { getSampleData(DOCS2); - DirichletClusterer dc = new DirichletClusterer(sampleData, new GaussianClusterDistribution(sampleData.get(0)), 1.0, - 15, 1, 0); - List<Cluster[]> result = dc.cluster(10); - assertNotNull(result); - printSamples(result, 0); - printClusters(result.get(result.size() - 1), sampleData, DOCS2); + DistributionDescription description = new DistributionDescription(GaussianClusterDistribution.class.getName(), + RandomAccessSparseVector.class.getName(), ManhattanDistanceMeasure.class.getName(), sampleData.get(0).size()); + + List<Cluster> models = Lists.newArrayList(); + ModelDistribution<VectorWritable> modelDist = description.createModelDistribution(new Configuration()); + for (Model<VectorWritable> cluster : modelDist.sampleFromPrior(15)) { + models.add((Cluster) cluster); + } + + ClusterIterator iterator = new ClusterIterator(); + ClusterClassifier classifier = new ClusterClassifier(models, new DirichletClusteringPolicy(15, 1.0)); + ClusterClassifier posterior = iterator.iterate(sampleData, classifier, 10); + + printClusters(posterior.getModels(), DOCS2); } @Test public void testDMDocs2() throws Exception { - getSampleData(DOCS2); - DirichletClusterer dc = new DirichletClusterer(sampleData, - new DistanceMeasureClusterDistribution(sampleData.get(0)), 1.0, 15, 1, 0); - List<Cluster[]> result = dc.cluster(10); - assertNotNull(result); - printSamples(result, 0); - printClusters(result.get(result.size() - 1), sampleData, DOCS2); + + getSampleData(DOCS); + DistributionDescription description = new DistributionDescription( + DistanceMeasureClusterDistribution.class.getName(), RandomAccessSparseVector.class.getName(), + CosineDistanceMeasure.class.getName(), sampleData.get(0).size()); + + List<Cluster> models = Lists.newArrayList(); + ModelDistribution<VectorWritable> modelDist = description.createModelDistribution(new Configuration()); + for (Model<VectorWritable> cluster : modelDist.sampleFromPrior(15)) { + models.add((Cluster) cluster); + } + + ClusterIterator iterator = new ClusterIterator(); + ClusterClassifier classifier = new ClusterClassifier(models, new DirichletClusteringPolicy(15, 1.0)); + ClusterClassifier posterior = iterator.iterate(sampleData, classifier, 10); + + printClusters(posterior.getModels(), DOCS2); } }
