Modified: mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestClusterInterface.java URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestClusterInterface.java?rev=987240&r1=987239&r2=987240&view=diff ============================================================================== --- mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestClusterInterface.java (original) +++ mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestClusterInterface.java Thu Aug 19 17:32:49 2010 @@ -21,10 +21,10 @@ import java.lang.reflect.Type; import org.apache.mahout.clustering.canopy.Canopy; import org.apache.mahout.clustering.dirichlet.DirichletCluster; +import org.apache.mahout.clustering.dirichlet.JsonClusterModelAdapter; import org.apache.mahout.clustering.dirichlet.JsonModelAdapter; import org.apache.mahout.clustering.dirichlet.models.AsymmetricSampledNormalModel; import org.apache.mahout.clustering.dirichlet.models.L1Model; -import org.apache.mahout.clustering.dirichlet.models.Model; import org.apache.mahout.clustering.dirichlet.models.NormalModel; import org.apache.mahout.clustering.dirichlet.models.SampledNormalModel; import org.apache.mahout.clustering.meanshift.MeanShiftCanopy; @@ -34,7 +34,6 @@ import org.apache.mahout.common.distance import org.apache.mahout.math.DenseVector; import org.apache.mahout.math.SequentialAccessSparseVector; import org.apache.mahout.math.Vector; -import org.apache.mahout.math.VectorWritable; import com.google.gson.Gson; import com.google.gson.GsonBuilder; @@ -43,7 +42,7 @@ import com.google.gson.reflect.TypeToken public class TestClusterInterface extends MahoutTestCase { private static final Type MODEL_TYPE = new TypeToken<Model<Vector>>() {}.getType(); - private static final Type CLUSTER_TYPE = new TypeToken<DirichletCluster<Vector>>() {}.getType(); + private static final Type CLUSTER_TYPE = new TypeToken<DirichletCluster>() {}.getType(); private static final DistanceMeasure measure = new ManhattanDistanceMeasure(); public void testDirichletNormalModel() { @@ -106,7 +105,7 @@ public class TestClusterInterface extend double[] d = { 1.1, 2.2, 3.3 }; Vector m = new DenseVector(d); NormalModel model = new NormalModel(5, m, 0.75); - Cluster cluster = new DirichletCluster<VectorWritable>(model, 35.0); + Cluster cluster = new DirichletCluster(model, 35.0); String format = cluster.asFormatString(null); assertEquals("format", "C-5: nm{n=0 m=[1.100, 2.200, 3.300] sd=0.75}", format); } @@ -115,12 +114,12 @@ public class TestClusterInterface extend double[] d = { 1.1, 2.2, 3.3 }; Vector m = new DenseVector(d); NormalModel model = new NormalModel(5, m, 0.75); - Cluster cluster = new DirichletCluster<VectorWritable>(model, 35.0); + Cluster cluster = new DirichletCluster(model, 35.0); String json = cluster.asJsonString(); GsonBuilder builder = new GsonBuilder(); - builder.registerTypeAdapter(Model.class, new JsonModelAdapter()); + builder.registerTypeAdapter(Cluster.class, new JsonClusterModelAdapter()); Gson gson = builder.create(); - DirichletCluster<VectorWritable> result = gson.fromJson(json, CLUSTER_TYPE); + DirichletCluster result = gson.fromJson(json, CLUSTER_TYPE); assertNotNull("result null", result); assertEquals("model", cluster.asFormatString(null), result.asFormatString(null)); } @@ -129,7 +128,7 @@ public class TestClusterInterface extend double[] d = { 1.1, 2.2, 3.3 }; Vector m = new DenseVector(d); AsymmetricSampledNormalModel model = new AsymmetricSampledNormalModel(5, m, m); - Cluster cluster = new DirichletCluster<VectorWritable>(model, 35.0); + Cluster cluster = new DirichletCluster(model, 35.0); String format = cluster.asFormatString(null); assertEquals("format", "C-5: asnm{n=0 m=[1.100, 2.200, 3.300] sd=[1.100, 2.200, 3.300]}", format); } @@ -138,13 +137,13 @@ public class TestClusterInterface extend double[] d = { 1.1, 2.2, 3.3 }; Vector m = new DenseVector(d); AsymmetricSampledNormalModel model = new AsymmetricSampledNormalModel(5, m, m); - Cluster cluster = new DirichletCluster<VectorWritable>(model, 35.0); + Cluster cluster = new DirichletCluster(model, 35.0); String json = cluster.asJsonString(); GsonBuilder builder = new GsonBuilder(); - builder.registerTypeAdapter(Model.class, new JsonModelAdapter()); + builder.registerTypeAdapter(Cluster.class, new JsonClusterModelAdapter()); Gson gson = builder.create(); - DirichletCluster<VectorWritable> result = gson.fromJson(json, CLUSTER_TYPE); + DirichletCluster result = gson.fromJson(json, CLUSTER_TYPE); assertNotNull("result null", result); assertEquals("model", cluster.asFormatString(null), result.asFormatString(null)); } @@ -153,7 +152,7 @@ public class TestClusterInterface extend double[] d = { 1.1, 2.2, 3.3 }; Vector m = new DenseVector(d); L1Model model = new L1Model(5, m); - Cluster cluster = new DirichletCluster<VectorWritable>(model, 35.0); + Cluster cluster = new DirichletCluster(model, 35.0); String format = cluster.asFormatString(null); assertEquals("format", "C-5: l1m{n=0 c=[1.100, 2.200, 3.300]}", format); } @@ -162,13 +161,13 @@ public class TestClusterInterface extend double[] d = { 1.1, 2.2, 3.3 }; Vector m = new DenseVector(d); L1Model model = new L1Model(5, m); - Cluster cluster = new DirichletCluster<VectorWritable>(model, 35.0); + Cluster cluster = new DirichletCluster(model, 35.0); String json = cluster.asJsonString(); GsonBuilder builder = new GsonBuilder(); - builder.registerTypeAdapter(Model.class, new JsonModelAdapter()); + builder.registerTypeAdapter(Cluster.class, new JsonClusterModelAdapter()); Gson gson = builder.create(); - DirichletCluster<VectorWritable> result = gson.fromJson(json, CLUSTER_TYPE); + DirichletCluster result = gson.fromJson(json, CLUSTER_TYPE); assertNotNull("result null", result); assertEquals("model", cluster.asFormatString(null), result.asFormatString(null)); }
Modified: mahout/trunk/core/src/test/java/org/apache/mahout/clustering/dirichlet/TestDirichletClustering.java URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/dirichlet/TestDirichletClustering.java?rev=987240&r1=987239&r2=987240&view=diff ============================================================================== --- mahout/trunk/core/src/test/java/org/apache/mahout/clustering/dirichlet/TestDirichletClustering.java (original) +++ mahout/trunk/core/src/test/java/org/apache/mahout/clustering/dirichlet/TestDirichletClustering.java Thu Aug 19 17:32:49 2010 @@ -20,8 +20,9 @@ package org.apache.mahout.clustering.dir import java.util.ArrayList; import java.util.List; +import org.apache.mahout.clustering.Cluster; +import org.apache.mahout.clustering.Model; import org.apache.mahout.clustering.dirichlet.models.AsymmetricSampledNormalDistribution; -import org.apache.mahout.clustering.dirichlet.models.Model; import org.apache.mahout.clustering.dirichlet.models.NormalModelDistribution; import org.apache.mahout.clustering.dirichlet.models.SampledNormalDistribution; import org.apache.mahout.common.MahoutTestCase; @@ -71,7 +72,7 @@ public class TestDirichletClustering ext generateSamples(num, mx, my, sd, 2); } - private static void printResults(List<Model<VectorWritable>[]> result, + private static void printResults(List<Cluster[]> result, int significant) { int row = 0; for (Model<VectorWritable>[] r : result) { @@ -92,10 +93,10 @@ public class TestDirichletClustering ext generateSamples(30, 1, 0, 0.1); generateSamples(30, 0, 1, 0.1); - DirichletClusterer<VectorWritable> dc = new DirichletClusterer<VectorWritable>( + DirichletClusterer dc = new DirichletClusterer( sampleData, new NormalModelDistribution(new VectorWritable( new DenseVector(2))), 1.0, 10, 1, 0); - List<Model<VectorWritable>[]> result = dc.cluster(30); + List<Cluster[]> result = dc.cluster(30); printResults(result, 2); assertNotNull(result); } @@ -106,10 +107,10 @@ public class TestDirichletClustering ext generateSamples(30, 1, 0, 0.1); generateSamples(30, 0, 1, 0.1); - DirichletClusterer<VectorWritable> dc = new DirichletClusterer<VectorWritable>( + DirichletClusterer dc = new DirichletClusterer( sampleData, new SampledNormalDistribution(new VectorWritable( new DenseVector(2))), 1.0, 10, 1, 0); - List<Model<VectorWritable>[]> result = dc.cluster(30); + List<Cluster[]> result = dc.cluster(30); printResults(result, 2); assertNotNull(result); } @@ -120,10 +121,10 @@ public class TestDirichletClustering ext generateSamples(30, 1, 0, 0.1); generateSamples(30, 0, 1, 0.1); - DirichletClusterer<VectorWritable> dc = new DirichletClusterer<VectorWritable>( + DirichletClusterer dc = new DirichletClusterer( sampleData, new AsymmetricSampledNormalDistribution(new VectorWritable( new DenseVector(2))), 1.0, 10, 1, 0); - List<Model<VectorWritable>[]> result = dc.cluster(30); + List<Cluster[]> result = dc.cluster(30); printResults(result, 2); assertNotNull(result); } @@ -134,10 +135,10 @@ public class TestDirichletClustering ext generateSamples(300, 1, 0, 0.1); generateSamples(300, 0, 1, 0.1); - DirichletClusterer<VectorWritable> dc = new DirichletClusterer<VectorWritable>( + DirichletClusterer dc = new DirichletClusterer( sampleData, new NormalModelDistribution(new VectorWritable( new DenseVector(2))), 1.0, 10, 1, 0); - List<Model<VectorWritable>[]> result = dc.cluster(30); + List<Cluster[]> result = dc.cluster(30); printResults(result, 20); assertNotNull(result); } @@ -148,10 +149,10 @@ public class TestDirichletClustering ext generateSamples(300, 1, 0, 0.1); generateSamples(300, 0, 1, 0.1); - DirichletClusterer<VectorWritable> dc = new DirichletClusterer<VectorWritable>( + DirichletClusterer dc = new DirichletClusterer( sampleData, new SampledNormalDistribution(new VectorWritable( new DenseVector(2))), 1.0, 10, 1, 0); - List<Model<VectorWritable>[]> result = dc.cluster(30); + List<Cluster[]> result = dc.cluster(30); printResults(result, 20); assertNotNull(result); } @@ -162,10 +163,10 @@ public class TestDirichletClustering ext generateSamples(300, 1, 0, 0.1); generateSamples(300, 0, 1, 0.1); - DirichletClusterer<VectorWritable> dc = new DirichletClusterer<VectorWritable>( + DirichletClusterer dc = new DirichletClusterer( sampleData, new AsymmetricSampledNormalDistribution(new VectorWritable( new DenseVector(2))), 1.0, 10, 1, 0); - List<Model<VectorWritable>[]> result = dc.cluster(30); + List<Cluster[]> result = dc.cluster(30); printResults(result, 20); assertNotNull(result); } @@ -176,10 +177,10 @@ public class TestDirichletClustering ext generateSamples(3000, 1, 0, 0.1); generateSamples(3000, 0, 1, 0.1); - DirichletClusterer<VectorWritable> dc = new DirichletClusterer<VectorWritable>( + DirichletClusterer dc = new DirichletClusterer( sampleData, new NormalModelDistribution(new VectorWritable( new DenseVector(2))), 1.0, 10, 1, 0); - List<Model<VectorWritable>[]> result = dc.cluster(30); + List<Cluster[]> result = dc.cluster(30); printResults(result, 200); assertNotNull(result); } @@ -190,10 +191,10 @@ public class TestDirichletClustering ext generateSamples(3000, 1, 0, 0.1); generateSamples(3000, 0, 1, 0.1); - DirichletClusterer<VectorWritable> dc = new DirichletClusterer<VectorWritable>( + DirichletClusterer dc = new DirichletClusterer( sampleData, new AsymmetricSampledNormalDistribution(new VectorWritable( new DenseVector(2))), 1.0, 10, 1, 0); - List<Model<VectorWritable>[]> result = dc.cluster(30); + List<Cluster[]> result = dc.cluster(30); printResults(result, 200); assertNotNull(result); } @@ -204,10 +205,10 @@ public class TestDirichletClustering ext generateSamples(3000, 1, 0, 0.1); generateSamples(3000, 0, 1, 0.1); - DirichletClusterer<VectorWritable> dc = new DirichletClusterer<VectorWritable>( + DirichletClusterer dc = new DirichletClusterer( sampleData, new SampledNormalDistribution(new VectorWritable( new DenseVector(2))), 1.0, 10, 1, 0); - List<Model<VectorWritable>[]> result = dc.cluster(30); + List<Cluster[]> result = dc.cluster(30); printResults(result, 200); assertNotNull(result); } @@ -218,10 +219,10 @@ public class TestDirichletClustering ext generateSamples(30, 1, 0, 0.1, 3); generateSamples(30, 0, 1, 0.1, 3); - DirichletClusterer<VectorWritable> dc = new DirichletClusterer<VectorWritable>( + DirichletClusterer dc = new DirichletClusterer( sampleData, new NormalModelDistribution(new VectorWritable( new DenseVector(3))), 1.0, 10, 1, 0); - List<Model<VectorWritable>[]> result = dc.cluster(30); + List<Cluster[]> result = dc.cluster(30); printResults(result, 2); assertNotNull(result); } @@ -232,10 +233,10 @@ public class TestDirichletClustering ext generateSamples(30, 1, 0, 0.1, 3); generateSamples(30, 0, 1, 0.1, 3); - DirichletClusterer<VectorWritable> dc = new DirichletClusterer<VectorWritable>( + DirichletClusterer dc = new DirichletClusterer( sampleData, new SampledNormalDistribution(new VectorWritable( new DenseVector(3))), 1.0, 10, 1, 0); - List<Model<VectorWritable>[]> result = dc.cluster(30); + List<Cluster[]> result = dc.cluster(30); printResults(result, 2); assertNotNull(result); } @@ -246,10 +247,10 @@ public class TestDirichletClustering ext generateSamples(30, 1, 0, 0.1, 3); generateSamples(30, 0, 1, 0.1, 3); - DirichletClusterer<VectorWritable> dc = new DirichletClusterer<VectorWritable>( + DirichletClusterer dc = new DirichletClusterer( sampleData, new AsymmetricSampledNormalDistribution(new VectorWritable( new DenseVector(3))), 1.0, 10, 1, 0); - List<Model<VectorWritable>[]> result = dc.cluster(30); + List<Cluster[]> result = dc.cluster(30); printResults(result, 2); assertNotNull(result); } Modified: mahout/trunk/core/src/test/java/org/apache/mahout/clustering/dirichlet/TestMapReduce.java URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/dirichlet/TestMapReduce.java?rev=987240&r1=987239&r2=987240&view=diff ============================================================================== --- mahout/trunk/core/src/test/java/org/apache/mahout/clustering/dirichlet/TestMapReduce.java (original) +++ mahout/trunk/core/src/test/java/org/apache/mahout/clustering/dirichlet/TestMapReduce.java Thu Aug 19 17:32:49 2010 @@ -30,9 +30,10 @@ import org.apache.hadoop.io.Text; import org.apache.hadoop.io.WritableComparable; import org.apache.hadoop.mapreduce.Mapper; import org.apache.hadoop.mapreduce.Reducer; +import org.apache.mahout.clustering.Cluster; import org.apache.mahout.clustering.ClusteringTestUtils; +import org.apache.mahout.clustering.Model; import org.apache.mahout.clustering.dirichlet.models.AsymmetricSampledNormalModel; -import org.apache.mahout.clustering.dirichlet.models.Model; import org.apache.mahout.clustering.dirichlet.models.NormalModel; import org.apache.mahout.clustering.dirichlet.models.NormalModelDistribution; import org.apache.mahout.clustering.dirichlet.models.SampledNormalDistribution; @@ -116,7 +117,7 @@ public class TestMapReduce extends Mahou /** Test the basic Mapper */ public void testMapper() throws Exception { generateSamples(10, 0, 0, 1); - DirichletState<VectorWritable> state = new DirichletState<VectorWritable>(new NormalModelDistribution(new VectorWritable(new DenseVector(2))), + DirichletState state = new DirichletState(new NormalModelDistribution(new VectorWritable(new DenseVector(2))), 5, 1); DirichletMapper mapper = new DirichletMapper(); @@ -140,7 +141,7 @@ public class TestMapReduce extends Mahou generateSamples(100, 2, 0, 1); generateSamples(100, 0, 2, 1); generateSamples(100, 2, 2, 1); - DirichletState<VectorWritable> state = new DirichletState<VectorWritable>(new SampledNormalDistribution(new VectorWritable(new DenseVector(2))), + DirichletState state = new DirichletState(new SampledNormalDistribution(new VectorWritable(new DenseVector(2))), 20, 1); DirichletMapper mapper = new DirichletMapper(); @@ -156,14 +157,14 @@ public class TestMapReduce extends Mahou DirichletReducer reducer = new DirichletReducer(); reducer.setup(state); - DummyRecordWriter<Text, DirichletCluster<VectorWritable>> reduceWriter = new DummyRecordWriter<Text, DirichletCluster<VectorWritable>>(); - Reducer<Text, VectorWritable, Text, DirichletCluster<VectorWritable>>.Context reduceContext = DummyRecordWriter + DummyRecordWriter<Text, DirichletCluster> reduceWriter = new DummyRecordWriter<Text, DirichletCluster>(); + Reducer<Text, VectorWritable, Text, DirichletCluster>.Context reduceContext = DummyRecordWriter .build(reducer, conf, reduceWriter, Text.class, VectorWritable.class); for (Text key : mapWriter.getKeys()) { reducer.reduce(new Text(key), mapWriter.getValue(key), reduceContext); } - Model<VectorWritable>[] newModels = reducer.getNewModels(); + Cluster[] newModels = reducer.getNewModels(); state.update(newModels); } @@ -173,7 +174,7 @@ public class TestMapReduce extends Mahou generateSamples(100, 2, 0, 1); generateSamples(100, 0, 2, 1); generateSamples(100, 2, 2, 1); - DirichletState<VectorWritable> state = new DirichletState<VectorWritable>(new SampledNormalDistribution(new VectorWritable(new DenseVector(2))), + DirichletState state = new DirichletState(new SampledNormalDistribution(new VectorWritable(new DenseVector(2))), 20, 1.0); @@ -192,14 +193,14 @@ public class TestMapReduce extends Mahou DirichletReducer reducer = new DirichletReducer(); reducer.setup(state); - DummyRecordWriter<Text, DirichletCluster<VectorWritable>> reduceWriter = new DummyRecordWriter<Text, DirichletCluster<VectorWritable>>(); - Reducer<Text, VectorWritable, Text, DirichletCluster<VectorWritable>>.Context reduceContext = DummyRecordWriter + DummyRecordWriter<Text, DirichletCluster> reduceWriter = new DummyRecordWriter<Text, DirichletCluster>(); + Reducer<Text, VectorWritable, Text, DirichletCluster>.Context reduceContext = DummyRecordWriter .build(reducer, conf, reduceWriter, Text.class, VectorWritable.class); for (Text key : mapWriter.getKeys()) { reducer.reduce(new Text(key), mapWriter.getValue(key), reduceContext); } - Model<VectorWritable>[] newModels = reducer.getNewModels(); + Cluster[] newModels = reducer.getNewModels(); state.update(newModels); models.add(newModels); } @@ -221,9 +222,9 @@ public class TestMapReduce extends Mahou System.out.println(); } - private static void printResults(List<List<DirichletCluster<VectorWritable>>> clusters, int significant) { + private static void printResults(List<List<DirichletCluster>> clusters, int significant) { int row = 0; - for (List<DirichletCluster<VectorWritable>> r : clusters) { + for (List<DirichletCluster> r : clusters) { System.out.print("sample[" + row++ + "]= "); for (int k = 0; k < r.size(); k++) { Model<VectorWritable> model = r.get(k).getModel(); @@ -257,7 +258,7 @@ public class TestMapReduce extends Mahou DefaultOptionCreator.SEQUENTIAL_METHOD }; new DirichletDriver().run(args); // and inspect results - List<List<DirichletCluster<VectorWritable>>> clusters = new ArrayList<List<DirichletCluster<VectorWritable>>>(); + List<List<DirichletCluster>> clusters = new ArrayList<List<DirichletCluster>>(); Configuration conf = new Configuration(); conf.set(DirichletDriver.MODEL_FACTORY_KEY, "org.apache.mahout.clustering.dirichlet.models.SampledNormalDistribution"); conf.set(DirichletDriver.MODEL_PROTOTYPE_KEY, "org.apache.mahout.math.DenseVector"); @@ -290,7 +291,7 @@ public class TestMapReduce extends Mahou optKey(DefaultOptionCreator.CLUSTERING_OPTION) }; new DirichletDriver().run(args); // and inspect results - List<List<DirichletCluster<VectorWritable>>> clusters = new ArrayList<List<DirichletCluster<VectorWritable>>>(); + List<List<DirichletCluster>> clusters = new ArrayList<List<DirichletCluster>>(); Configuration conf = new Configuration(); conf.set(DirichletDriver.MODEL_FACTORY_KEY, "org.apache.mahout.clustering.dirichlet.models.SampledNormalDistribution"); conf.set(DirichletDriver.MODEL_PROTOTYPE_KEY, "org.apache.mahout.math.DenseVector"); @@ -322,7 +323,7 @@ public class TestMapReduce extends Mahou 0, false); // and inspect results - List<List<DirichletCluster<VectorWritable>>> clusters = new ArrayList<List<DirichletCluster<VectorWritable>>>(); + List<List<DirichletCluster>> clusters = new ArrayList<List<DirichletCluster>>(); Configuration conf = new Configuration(); conf.set(DirichletDriver.MODEL_FACTORY_KEY, "org.apache.mahout.clustering.dirichlet.models.SampledNormalDistribution"); conf.set(DirichletDriver.MODEL_PROTOTYPE_KEY, "org.apache.mahout.math.DenseVector"); @@ -368,7 +369,7 @@ public class TestMapReduce extends Mahou 0, false); // and inspect results - List<List<DirichletCluster<VectorWritable>>> clusters = new ArrayList<List<DirichletCluster<VectorWritable>>>(); + List<List<DirichletCluster>> clusters = new ArrayList<List<DirichletCluster>>(); Configuration conf = new Configuration(); conf.set(DirichletDriver.MODEL_FACTORY_KEY, "org.apache.mahout.clustering.dirichlet.models.SampledNormalDistribution"); conf.set(DirichletDriver.MODEL_PROTOTYPE_KEY, "org.apache.mahout.math.DenseVector"); @@ -414,7 +415,7 @@ public class TestMapReduce extends Mahou 0, false); // and inspect results - List<List<DirichletCluster<VectorWritable>>> clusters = new ArrayList<List<DirichletCluster<VectorWritable>>>(); + List<List<DirichletCluster>> clusters = new ArrayList<List<DirichletCluster>>(); Configuration conf = new Configuration(); conf .set(DirichletDriver.MODEL_FACTORY_KEY, "org.apache.mahout.clustering.dirichlet.models.AsymmetricSampledNormalDistribution"); @@ -470,10 +471,10 @@ public class TestMapReduce extends Mahou public void testClusterWritableSerialization() throws Exception { double[] m = { 1.1, 2.2, 3.3 }; - DirichletCluster<?> cluster = new DirichletCluster<VectorWritable>(new NormalModel(5, new DenseVector(m), 4), 10); + DirichletCluster cluster = new DirichletCluster(new NormalModel(5, new DenseVector(m), 4), 10); DataOutputBuffer out = new DataOutputBuffer(); cluster.write(out); - DirichletCluster<?> cluster2 = new DirichletCluster<VectorWritable>(); + DirichletCluster cluster2 = new DirichletCluster(); DataInputBuffer in = new DataInputBuffer(); in.reset(out.getData(), out.getLength()); cluster2.readFields(in); Modified: mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayCanopy.java URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayCanopy.java?rev=987240&r1=987239&r2=987240&view=diff ============================================================================== --- mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayCanopy.java (original) +++ mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayCanopy.java Thu Aug 19 17:32:49 2010 @@ -17,20 +17,25 @@ package org.apache.mahout.clustering.display; +import java.awt.BasicStroke; +import java.awt.Color; import java.awt.Graphics; import java.awt.Graphics2D; +import java.util.List; import org.apache.hadoop.fs.Path; +import org.apache.mahout.clustering.Cluster; import org.apache.mahout.clustering.canopy.CanopyDriver; import org.apache.mahout.common.HadoopUtil; import org.apache.mahout.common.RandomUtils; import org.apache.mahout.common.distance.ManhattanDistanceMeasure; +import org.apache.mahout.math.DenseVector; class DisplayCanopy extends DisplayClustering { DisplayCanopy() { initialize(); - this.setTitle("Canopy Clusters (>" + (int) (getSignificance() * 100) + "% of population)"); + this.setTitle("Canopy Clusters (>" + (int) (significance * 100) + "% of population)"); } @Override @@ -39,6 +44,24 @@ class DisplayCanopy extends DisplayClust plotClusters((Graphics2D) g); } + protected static void plotClusters(Graphics2D g2) { + int cx = CLUSTERS.size() - 1; + for (List<Cluster> clusters : CLUSTERS) { + for (Cluster cluster : clusters) { + g2.setStroke(new BasicStroke(1)); + g2.setColor(Color.BLUE); + double[] t1 = { T1, T1 }; + plotEllipse(g2, cluster.getCenter(), new DenseVector(t1)); + double[] t2 = { T2, T2 }; + plotEllipse(g2, cluster.getCenter(), new DenseVector(t2)); + g2.setColor(COLORS[Math.min(DisplayClustering.COLORS.length - 1, cx)]); + g2.setStroke(new BasicStroke(cx == 0 ? 3 : 1)); + plotEllipse(g2, cluster.getCenter(), cluster.getRadius().times(3)); + } + cx--; + } + } + public static void main(String[] args) throws Exception { //SIGNIFICANCE = 0.05; Path samples = new Path("samples"); @@ -50,8 +73,8 @@ class DisplayCanopy extends DisplayClust writeSampleData(samples); //boolean b = true; //if (b) { - new CanopyDriver().buildClusters(samples, output, new ManhattanDistanceMeasure(), T1, T2, true); - loadClusters(output); + new CanopyDriver().buildClusters(samples, output, new ManhattanDistanceMeasure(), T1, T2, true); + loadClusters(output); //} else { // List<Vector> points = new ArrayList<Vector>(); // for (VectorWritable sample : SAMPLE_DATA) { Modified: mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayClustering.java URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayClustering.java?rev=987240&r1=987239&r2=987240&view=diff ============================================================================== --- mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayClustering.java (original) +++ mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayClustering.java Thu Aug 19 17:32:49 2010 @@ -70,6 +70,8 @@ public class DisplayClustering extends F protected static final double T1 = 3.0; protected static final double T2 = 2.8; + + protected static double significance = 0.05; protected static int res; // screen resolution @@ -298,12 +300,7 @@ public class DisplayClustering extends F } } - protected boolean isSignificant(Cluster cluster) { - return (double) cluster.getNumPoints() / SAMPLE_DATA.size() > getSignificance(); - } - - protected double getSignificance() { - return 0.05; + protected static boolean isSignificant(Cluster cluster) { + return (double) cluster.getNumPoints() / SAMPLE_DATA.size() > significance; } - } Modified: mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayDirichlet.java URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayDirichlet.java?rev=987240&r1=987239&r2=987240&view=diff ============================================================================== --- mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayDirichlet.java (original) +++ mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayDirichlet.java Thu Aug 19 17:32:49 2010 @@ -23,10 +23,9 @@ import java.util.ArrayList; import java.util.List; import org.apache.mahout.clustering.Cluster; +import org.apache.mahout.clustering.ModelDistribution; import org.apache.mahout.clustering.dirichlet.DirichletClusterer; -import org.apache.mahout.clustering.dirichlet.models.AsymmetricSampledNormalDistribution; -import org.apache.mahout.clustering.dirichlet.models.Model; -import org.apache.mahout.clustering.dirichlet.models.ModelDistribution; +import org.apache.mahout.clustering.dirichlet.models.GaussianClusterDistribution; import org.apache.mahout.common.RandomUtils; import org.apache.mahout.math.DenseVector; import org.apache.mahout.math.VectorWritable; @@ -39,8 +38,7 @@ public class DisplayDirichlet extends Di public DisplayDirichlet() { initialize(); - this.setTitle("Dirichlet Process Clusters - Normal Distribution (>" + - (int) (getSignificance() * 100) + "% of population)"); + this.setTitle("Dirichlet Process Clusters - Normal Distribution (>" + (int) (significance * 100) + "% of population)"); } // Override the paint() method @@ -50,15 +48,15 @@ public class DisplayDirichlet extends Di plotClusters((Graphics2D) g); } - protected static void printModels(Iterable<Model<VectorWritable>[]> results, int significant) { + protected static void printModels(List<Cluster[]> result, int significant) { int row = 0; StringBuilder models = new StringBuilder(); - for (Model<VectorWritable>[] r : results) { + for (Cluster[] r : result) { models.append("sample[").append(row++).append("]= "); for (int k = 0; k < r.length; k++) { - Model<VectorWritable> model = r[k]; + Cluster model = r[k]; if (model.count() > significant) { - models.append('m').append(k).append(model).append(", "); + models.append('m').append(k).append(model.asFormatString(null)).append(", "); } } models.append('\n'); @@ -67,25 +65,20 @@ public class DisplayDirichlet extends Di log.info(models.toString()); } - protected void generateResults(ModelDistribution<VectorWritable> modelDist, + protected static void generateResults(ModelDistribution<VectorWritable> modelDist, int numClusters, int numIterations, double alpha_0, int thin, int burnin) { - DirichletClusterer<VectorWritable> dc = new DirichletClusterer<VectorWritable>(SAMPLE_DATA, - modelDist, - alpha_0, - numClusters, - thin, - burnin); - List<Model<VectorWritable>[]> result = dc.cluster(numIterations); + DirichletClusterer dc = new DirichletClusterer(SAMPLE_DATA, modelDist, alpha_0, numClusters, thin, burnin); + List<Cluster[]> result = dc.cluster(numIterations); printModels(result, burnin); - for (Model<VectorWritable>[] models : result) { + for (Cluster[] models : result) { List<Cluster> clusters = new ArrayList<Cluster>(); - for (Model<VectorWritable> cluster : models) { - if (isSignificant((Cluster)cluster)) { - clusters.add((Cluster)cluster); + for (Cluster cluster : models) { + if (isSignificant(cluster)) { + clusters.add((Cluster) cluster); } } CLUSTERS.add(clusters); @@ -96,7 +89,8 @@ public class DisplayDirichlet extends Di VectorWritable modelPrototype = new VectorWritable(new DenseVector(2)); // ModelDistribution<VectorWritable> modelDist = new NormalModelDistribution(modelPrototype); // ModelDistribution<VectorWritable> modelDist = new SampledNormalDistribution(modelPrototype); - ModelDistribution<VectorWritable> modelDist = new AsymmetricSampledNormalDistribution(modelPrototype); + // ModelDistribution<VectorWritable> modelDist = new AsymmetricSampledNormalDistribution(modelPrototype); + ModelDistribution<VectorWritable> modelDist = new GaussianClusterDistribution(modelPrototype); RandomUtils.useTestSeed(); generateSamples(); @@ -105,7 +99,8 @@ public class DisplayDirichlet extends Di int alpha_0 = 1; int thin = 3; int burnin = 5; - new DisplayDirichlet().generateResults(modelDist, numClusters, numIterations, alpha_0, thin, burnin); + generateResults(modelDist, numClusters, numIterations, alpha_0, thin, burnin); + new DisplayDirichlet(); } } Modified: mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayFuzzyKMeans.java URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayFuzzyKMeans.java?rev=987240&r1=987239&r2=987240&view=diff ============================================================================== --- mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayFuzzyKMeans.java (original) +++ mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayFuzzyKMeans.java Thu Aug 19 17:32:49 2010 @@ -32,7 +32,7 @@ class DisplayFuzzyKMeans extends Display DisplayFuzzyKMeans() { initialize(); - this.setTitle("Fuzzy k-Means Clusters (>" + (int) (getSignificance() * 100) + "% of population)"); + this.setTitle("Fuzzy k-Means Clusters (>" + (int) (significance * 100) + "% of population)"); } // Override the paint() method Modified: mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayKMeans.java URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayKMeans.java?rev=987240&r1=987239&r2=987240&view=diff ============================================================================== --- mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayKMeans.java (original) +++ mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayKMeans.java Thu Aug 19 17:32:49 2010 @@ -34,7 +34,7 @@ class DisplayKMeans extends DisplayClust DisplayKMeans() { initialize(); - this.setTitle("k-Means Clusters (>" + (int) (getSignificance() * 100) + "% of population)"); + this.setTitle("k-Means Clusters (>" + (int) (significance * 100) + "% of population)"); } public static void main(String[] args) throws Exception { Modified: mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayMeanShift.java URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayMeanShift.java?rev=987240&r1=987239&r2=987240&view=diff ============================================================================== --- mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayMeanShift.java (original) +++ mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/display/DisplayMeanShift.java Thu Aug 19 17:32:49 2010 @@ -49,7 +49,7 @@ final class DisplayMeanShift extends Dis private DisplayMeanShift() { initialize(); - this.setTitle("k-Means Clusters (>" + (int) (getSignificance() * 100) + "% of population)"); + this.setTitle("Mean Shift Canopy Clusters (>" + (int) (significance * 100) + "% of population)"); } @Override @@ -75,7 +75,7 @@ final class DisplayMeanShift extends Dis int i = 0; for (Cluster cluster : CLUSTERS.get(CLUSTERS.size()-1)) { MeanShiftCanopy canopy = (MeanShiftCanopy) cluster; - if (canopy.getBoundPoints().toList().size() >= getSignificance() * DisplayClustering.SAMPLE_DATA.size()) { + if (canopy.getBoundPoints().toList().size() >= significance * DisplayClustering.SAMPLE_DATA.size()) { g2.setColor(COLORS[Math.min(i++, DisplayClustering.COLORS.length - 1)]); int count = 0; Vector center = new DenseVector(2); @@ -92,15 +92,11 @@ final class DisplayMeanShift extends Dis } } - @Override - protected double getSignificance() { - return 0.02; - } - public static void main(String[] args) throws Exception { t1 = 1.5; t2 = 0.5; DistanceMeasure measure = new EuclideanDistanceMeasure(); + significance = 0.02; Path samples = new Path("samples"); Path output = new Path("output"); Modified: mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/dirichlet/Job.java URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/dirichlet/Job.java?rev=987240&r1=987239&r2=987240&view=diff ============================================================================== --- mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/dirichlet/Job.java (original) +++ mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/dirichlet/Job.java Thu Aug 19 17:32:49 2010 @@ -28,10 +28,10 @@ import org.apache.commons.cli2.builder.A import org.apache.commons.cli2.builder.DefaultOptionBuilder; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.Path; +import org.apache.mahout.clustering.Model; import org.apache.mahout.clustering.dirichlet.DirichletCluster; import org.apache.mahout.clustering.dirichlet.DirichletDriver; import org.apache.mahout.clustering.dirichlet.DirichletMapper; -import org.apache.mahout.clustering.dirichlet.models.Model; import org.apache.mahout.clustering.dirichlet.models.NormalModelDistribution; import org.apache.mahout.clustering.syntheticcontrol.Constants; import org.apache.mahout.clustering.syntheticcontrol.canopy.InputDriver; @@ -72,8 +72,7 @@ public final class Job extends Dirichlet } @Override - public int run(String[] args) - throws IOException, ClassNotFoundException, InstantiationException, IllegalAccessException, + public int run(String[] args) throws IOException, ClassNotFoundException, InstantiationException, IllegalAccessException, NoSuchMethodException, InvocationTargetException, InterruptedException { addInputOption(); addOutputOption(); @@ -145,9 +144,8 @@ public final class Job extends Dirichlet double alpha0, int numReducers, boolean emitMostLikely, - double threshold) - throws IOException, ClassNotFoundException, InstantiationException, IllegalAccessException, - NoSuchMethodException, InvocationTargetException, SecurityException, InterruptedException { + double threshold) throws IOException, ClassNotFoundException, InstantiationException, IllegalAccessException, + NoSuchMethodException, InvocationTargetException, SecurityException, InterruptedException { Path directoryContainingConvertedInput = new Path(output, Constants.DIRECTORY_CONTAINING_CONVERTED_INPUT); InputDriver.runJob(input, directoryContainingConvertedInput, modelPrototype); DirichletDriver.runJob(directoryContainingConvertedInput, @@ -160,10 +158,11 @@ public final class Job extends Dirichlet numReducers, true, emitMostLikely, - threshold, false); + threshold, + false); // run ClusterDumper - ClusterDumper clusterDumper = new ClusterDumper(new Path(output, "clusters-" + maxIterations), - new Path(output, "clusteredPoints")); + ClusterDumper clusterDumper = new ClusterDumper(new Path(output, "clusters-" + maxIterations), new Path(output, + "clusteredPoints")); clusterDumper.printClusters(null); } @@ -192,7 +191,7 @@ public final class Job extends Dirichlet int numIterations, int numModels, double alpha0) throws NoSuchMethodException, InvocationTargetException { - Collection<List<DirichletCluster<VectorWritable>>> clusters = new ArrayList<List<DirichletCluster<VectorWritable>>>(); + Collection<List<DirichletCluster>> clusters = new ArrayList<List<DirichletCluster>>(); Configuration conf = new Configuration(); conf.set(MODEL_FACTORY_KEY, modelDistribution); conf.set(NUM_CLUSTERS_KEY, Integer.toString(numModels)); @@ -215,10 +214,10 @@ public final class Job extends Dirichlet * @param significant * the minimum number of samples to enable printing a model */ - private static void printClusters(Iterable<List<DirichletCluster<VectorWritable>>> clusters, int significant) { + private static void printClusters(Iterable<List<DirichletCluster>> clusters, int significant) { int row = 0; StringBuilder result = new StringBuilder(); - for (List<DirichletCluster<VectorWritable>> r : clusters) { + for (List<DirichletCluster> r : clusters) { result.append("sample=").append(row++).append("]= "); for (int k = 0; k < r.size(); k++) { Model<VectorWritable> model = r.get(k).getModel(); Modified: mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/dirichlet/NormalScModelDistribution.java URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/dirichlet/NormalScModelDistribution.java?rev=987240&r1=987239&r2=987240&view=diff ============================================================================== --- mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/dirichlet/NormalScModelDistribution.java (original) +++ mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/syntheticcontrol/dirichlet/NormalScModelDistribution.java Thu Aug 19 17:32:49 2010 @@ -17,8 +17,8 @@ package org.apache.mahout.clustering.syntheticcontrol.dirichlet; +import org.apache.mahout.clustering.Model; import org.apache.mahout.clustering.dirichlet.UncommonDistributions; -import org.apache.mahout.clustering.dirichlet.models.Model; import org.apache.mahout.clustering.dirichlet.models.NormalModel; import org.apache.mahout.clustering.dirichlet.models.NormalModelDistribution; import org.apache.mahout.math.DenseVector; Modified: mahout/trunk/utils/src/main/java/org/apache/mahout/clustering/cdbw/CDbwDriver.java URL: http://svn.apache.org/viewvc/mahout/trunk/utils/src/main/java/org/apache/mahout/clustering/cdbw/CDbwDriver.java?rev=987240&r1=987239&r2=987240&view=diff ============================================================================== --- mahout/trunk/utils/src/main/java/org/apache/mahout/clustering/cdbw/CDbwDriver.java (original) +++ mahout/trunk/utils/src/main/java/org/apache/mahout/clustering/cdbw/CDbwDriver.java Thu Aug 19 17:32:49 2010 @@ -153,7 +153,7 @@ public final class CDbwDriver extends Ab SequenceFile.Writer writer = new SequenceFile.Writer(fs, conf, path, IntWritable.class, VectorWritable.class); while (reader.next(key, value)) { Cluster cluster = (Cluster) value; - if (!(cluster instanceof DirichletCluster<?>) || ((DirichletCluster<?>) cluster).getTotalCount() > 0) { + if (!(cluster instanceof DirichletCluster) || ((DirichletCluster) cluster).getTotalCount() > 0) { //System.out.println("C-" + cluster.getId() + ": " + ClusterBase.formatVector(cluster.getCenter(), null)); writer.append(new IntWritable(cluster.getId()), new VectorWritable(cluster.getCenter())); } Modified: mahout/trunk/utils/src/test/java/org/apache/mahout/clustering/dirichlet/TestL1ModelClustering.java URL: http://svn.apache.org/viewvc/mahout/trunk/utils/src/test/java/org/apache/mahout/clustering/dirichlet/TestL1ModelClustering.java?rev=987240&r1=987239&r2=987240&view=diff ============================================================================== --- mahout/trunk/utils/src/test/java/org/apache/mahout/clustering/dirichlet/TestL1ModelClustering.java (original) +++ mahout/trunk/utils/src/test/java/org/apache/mahout/clustering/dirichlet/TestL1ModelClustering.java Thu Aug 19 17:32:49 2010 @@ -33,8 +33,10 @@ import org.apache.lucene.index.IndexRead import org.apache.lucene.index.IndexWriter; import org.apache.lucene.store.RAMDirectory; import org.apache.lucene.util.Version; +import org.apache.mahout.clustering.Cluster; +import org.apache.mahout.clustering.Model; +import org.apache.mahout.clustering.dirichlet.models.DistanceMeasureClusterDistribution; import org.apache.mahout.clustering.dirichlet.models.L1ModelDistribution; -import org.apache.mahout.clustering.dirichlet.models.Model; import org.apache.mahout.common.MahoutTestCase; import org.apache.mahout.common.RandomUtils; import org.apache.mahout.math.Vector; @@ -49,18 +51,18 @@ import org.apache.mahout.utils.vectors.l import org.junit.Before; public class TestL1ModelClustering extends MahoutTestCase { - + private class MapElement implements Comparable<MapElement> { - + MapElement(double pdf, String doc) { this.pdf = pdf; this.doc = doc; } - + private final Double pdf; - + private final String doc; - + @Override // reverse compare to sort in reverse order public int compareTo(MapElement e) { @@ -72,53 +74,45 @@ public class TestL1ModelClustering exten return 0; } } - + @Override public String toString() { return pdf.toString() + ' ' + doc; } - + } - - private static final String[] DOCS = {"The quick red fox jumped over the lazy brown dogs.", - "The quick brown fox jumped over the lazy red dogs.", - "The quick red cat jumped over the lazy brown dogs.", - "The quick brown cat jumped over the lazy red dogs.", - "Mary had a little lamb whose fleece was white as snow.", - "Moby Dick is a story of a whale and a man obsessed.", - "The robber wore a black fleece jacket and a baseball cap.", - "The English Springer Spaniel is the best of all dogs."}; - + + private static final String[] DOCS = { "The quick red fox jumped over the lazy brown dogs.", + "The quick brown fox jumped over the lazy red dogs.", "The quick red cat jumped over the lazy brown dogs.", + "The quick brown cat jumped over the lazy red dogs.", "Mary had a little lamb whose fleece was white as snow.", + "Moby Dick is a story of a whale and a man obsessed.", "The robber wore a black fleece jacket and a baseball cap.", + "The English Springer Spaniel is the best of all dogs." }; + private List<VectorWritable> sampleData; - - private static final String[] DOCS2 = {"The quick red fox jumped over the lazy brown dogs.", - "The quick brown fox jumped over the lazy red dogs.", - "The quick red cat jumped over the lazy brown dogs.", - "The quick brown cat jumped over the lazy red dogs.", - "Mary had a little lamb whose fleece was white as snow.", - "Mary had a little goat whose fleece was white as snow.", - "Mary had a little lamb whose fleece was black as tar.", - "Dick had a little goat whose fleece was white as snow.", - "Moby Dick is a story of a whale and a man obsessed.", - "Moby Bob is a story of a walrus and a man obsessed.", - "Moby Dick is a story of a whale and a crazy man.", - "The robber wore a black fleece jacket and a baseball cap.", - "The robber wore a red fleece jacket and a baseball cap.", - "The robber wore a white fleece jacket and a baseball cap.", - "The English Springer Spaniel is the best of all dogs."}; - + + private static final String[] DOCS2 = { "The quick red fox jumped over the lazy brown dogs.", + "The quick brown fox jumped over the lazy red dogs.", "The quick red cat jumped over the lazy brown dogs.", + "The quick brown cat jumped over the lazy red dogs.", "Mary had a little lamb whose fleece was white as snow.", + "Mary had a little goat whose fleece was white as snow.", "Mary had a little lamb whose fleece was black as tar.", + "Dick had a little goat whose fleece was white as snow.", "Moby Dick is a story of a whale and a man obsessed.", + "Moby Bob is a story of a walrus and a man obsessed.", "Moby Dick is a story of a whale and a crazy man.", + "The robber wore a black fleece jacket and a baseball cap.", "The robber wore a red fleece jacket and a baseball cap.", + "The robber wore a white fleece jacket and a baseball cap.", "The English Springer Spaniel is the best of all dogs." }; + @Override @Before public void setUp() throws Exception { super.setUp(); RandomUtils.useTestSeed(); } - + private void getSampleData(String[] docs2) throws IOException { sampleData = new ArrayList<VectorWritable>(); RAMDirectory directory = new RAMDirectory(); - IndexWriter writer = new IndexWriter(directory, new StandardAnalyzer(Version.LUCENE_CURRENT), true, - IndexWriter.MaxFieldLength.UNLIMITED); + IndexWriter writer = new IndexWriter(directory, + new StandardAnalyzer(Version.LUCENE_CURRENT), + true, + IndexWriter.MaxFieldLength.UNLIMITED); for (int i = 0; i < docs2.length; i++) { Document doc = new Document(); Field id = new Field("id", "doc_" + i, Field.Store.YES, Field.Index.NOT_ANALYZED_NO_NORMS); @@ -134,7 +128,7 @@ public class TestL1ModelClustering exten TermInfo termInfo = new CachedTermInfo(reader, "content", 1, 100); VectorMapper mapper = new TFDFMapper(reader, weight, termInfo); LuceneIterable iterable = new LuceneIterable(reader, "id", "content", mapper); - + int i = 0; for (Vector vector : iterable) { Assert.assertNotNull(vector); @@ -142,7 +136,7 @@ public class TestL1ModelClustering exten sampleData.add(new VectorWritable(vector)); } } - + private static String formatVector(Vector v) { StringBuilder buf = new StringBuilder(); int nzero = 0; @@ -169,27 +163,27 @@ public class TestL1ModelClustering exten buf.append(']'); return buf.toString(); } - - private static void printSamples(List<Model<VectorWritable>[]> result, int significant) { + + private static void printSamples(List<Cluster[]> result, int significant) { int row = 0; - for (Model<VectorWritable>[] r : result) { + for (Cluster[] r : result) { int sig = 0; - for (Model<VectorWritable> model : r) { + for (Cluster model : r) { if (model.count() > significant) { sig++; } } System.out.print("sample[" + row++ + "] (" + sig + ")= "); - for (Model<VectorWritable> model : r) { + for (Cluster model : r) { if (model.count() > significant) { - System.out.print(model.toString() + ", "); + System.out.print(model.asFormatString(null) + ", "); } } System.out.println(); } System.out.println(); } - + private void printClusters(Model<VectorWritable>[] models, List<VectorWritable> samples, String[] docs) { for (int m = 0; m < models.length; m++) { Model<VectorWritable> model = models[m]; @@ -198,7 +192,7 @@ public class TestL1ModelClustering exten continue; } System.out.println("Model[" + m + "] had " + count + " hits (!) and " + (samples.size() - count) - + " misses (? in pdf order) during the last iteration:"); + + " misses (? in pdf order) during the last iteration:"); MapElement[] map = new MapElement[samples.size()]; // sort the samples by pdf for (int i = 0; i < samples.size(); i++) { @@ -217,27 +211,55 @@ public class TestL1ModelClustering exten } } } - + public void testDocs() throws Exception { System.out.println("testDocs"); getSampleData(DOCS); - DirichletClusterer<VectorWritable> dc = new DirichletClusterer<VectorWritable>(sampleData, - new L1ModelDistribution(sampleData.get(0)), 1.0, 15, 1, 0); - List<Model<VectorWritable>[]> result = dc.cluster(10); + DirichletClusterer dc = new DirichletClusterer(sampleData, new L1ModelDistribution(sampleData.get(0)), 1.0, 15, 1, 0); + List<Cluster[]> result = dc.cluster(10); + Assert.assertNotNull(result); + printSamples(result, 0); + printClusters(result.get(result.size() - 1), sampleData, DOCS); + } + + public void testDMDocs() throws Exception { + System.out.println("DM testDocs"); + getSampleData(DOCS); + DirichletClusterer dc = new DirichletClusterer(sampleData, + new DistanceMeasureClusterDistribution(sampleData.get(0)), + 1.0, + 15, + 1, + 0); + List<Cluster[]> result = dc.cluster(10); Assert.assertNotNull(result); printSamples(result, 0); printClusters(result.get(result.size() - 1), sampleData, DOCS); } - + public void testDocs2() throws Exception { System.out.println("testDocs2"); getSampleData(DOCS2); - DirichletClusterer<VectorWritable> dc = new DirichletClusterer<VectorWritable>(sampleData, - new L1ModelDistribution(sampleData.get(0)), 1.0, 15, 1, 0); - List<Model<VectorWritable>[]> result = dc.cluster(10); + DirichletClusterer dc = new DirichletClusterer(sampleData, new L1ModelDistribution(sampleData.get(0)), 1.0, 15, 1, 0); + List<Cluster[]> result = dc.cluster(10); Assert.assertNotNull(result); printSamples(result, 0); printClusters(result.get(result.size() - 1), sampleData, DOCS2); } - + + public void testDMDocs2() throws Exception { + System.out.println("DM testDocs2"); + getSampleData(DOCS2); + DirichletClusterer dc = new DirichletClusterer(sampleData, + new DistanceMeasureClusterDistribution(sampleData.get(0)), + 1.0, + 15, + 1, + 0); + List<Cluster[]> result = dc.cluster(10); + Assert.assertNotNull(result); + printSamples(result, 0); + printClusters(result.get(result.size() - 1), sampleData, DOCS2); + } + }
