Author: jeastman Date: Mon Jan 18 19:23:25 2010 New Revision: 900519 URL: http://svn.apache.org/viewvc?rev=900519&view=rev Log: MAHOUT-251
- DirichletDriver, DirichletJob - added command line arguments for modelPrototypeClass and prototypeSize - Display2dASNDirichlet, DisplayASNDirichlet, DisplayASNOutputState, DisplayDirichlet, DisplayNDirichlet, DisplayOutputState, DisplaySNDirichlet - added modelPrototype initializations to model distributions - math/SquareRootFunction - changed computation from abs() to sqrt(), as the previous abs() call had crept in during recent refactoring and was completely breaking the Dirichlet std() calculations - Cluster - changed std calculation to remove size=2 assumption - NormalModel - changed computeParameters to use std.size() vs s1.size() for uniformity - AsymmetricSampledNormalModel - changed comment in computeParameters. All tests run. Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletDriver.java lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletJob.java lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/AsymmetricSampledNormalModel.java lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/NormalModel.java lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/Cluster.java lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/dirichlet/Display2dASNDirichlet.java lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/dirichlet/DisplayASNDirichlet.java lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/dirichlet/DisplayASNOutputState.java lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/dirichlet/DisplayDirichlet.java lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/dirichlet/DisplayNDirichlet.java lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/dirichlet/DisplayOutputState.java lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/dirichlet/DisplaySNDirichlet.java 
lucene/mahout/trunk/math/src/main/java/org/apache/mahout/math/SquareRootFunction.java Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletDriver.java URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletDriver.java?rev=900519&r1=900518&r2=900519&view=diff ============================================================================== --- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletDriver.java (original) +++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletDriver.java Mon Jan 18 19:23:25 2010 @@ -88,11 +88,20 @@ abuilder.withName("modelClass").withMinimum(1).withMaximum(1).create()) .withDescription("The ModelDistribution class name.").create(); + Option prototypeOpt = obuilder.withLongName("modelPrototypeClass").withRequired(true).withShortName("p").withArgument( + abuilder.withName("prototypeClass").withMinimum(1).withMaximum(1).create()).withDescription( + "The ModelDistribution prototype Vector class name.").create(); + + Option sizeOpt = obuilder.withLongName("prototypeSize").withRequired(true).withShortName("s").withArgument( + abuilder.withName("prototypeSize").withMinimum(1).withMaximum(1).create()).withDescription( + "The ModelDistribution prototype Vector size.").create(); + Option numRedOpt = obuilder.withLongName("maxRed").withRequired(true).withShortName("r").withArgument( abuilder.withName("maxRed").withMinimum(1).withMaximum(1).create()).withDescription("The number of reduce tasks.").create(); Group group = gbuilder.withName("Options").withOption(inputOpt).withOption(outputOpt).withOption(modelOpt).withOption( - maxIterOpt).withOption(mOpt).withOption(topicsOpt).withOption(helpOpt).withOption(numRedOpt).create(); + prototypeOpt).withOption(sizeOpt).withOption(maxIterOpt).withOption(mOpt).withOption(topicsOpt).withOption(helpOpt) + 
.withOption(numRedOpt).create(); try { Parser parser = new Parser(); @@ -106,11 +115,13 @@ String input = cmdLine.getValue(inputOpt).toString(); String output = cmdLine.getValue(outputOpt).toString(); String modelFactory = cmdLine.getValue(modelOpt).toString(); + String modelPrototype = cmdLine.getValue(prototypeOpt).toString(); + int prototypeSize = Integer.parseInt(cmdLine.getValue(sizeOpt).toString()); int numReducers = Integer.parseInt(cmdLine.getValue(numRedOpt).toString()); int numModels = Integer.parseInt(cmdLine.getValue(topicsOpt).toString()); int maxIterations = Integer.parseInt(cmdLine.getValue(maxIterOpt).toString()); double alpha_0 = Double.parseDouble(cmdLine.getValue(mOpt).toString()); - runJob(input, output, modelFactory, numModels, maxIterations, alpha_0, numReducers); + runJob(input, output, modelFactory, modelPrototype, prototypeSize, numModels, maxIterations, alpha_0, numReducers); } catch (OptionException e) { log.error("Exception parsing command line: ", e); CommandLineUtil.printHelp(group); Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletJob.java URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletJob.java?rev=900519&r1=900518&r2=900519&view=diff ============================================================================== --- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletJob.java (original) +++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletJob.java Mon Jan 18 19:23:25 2010 @@ -44,8 +44,8 @@ private DirichletJob() { } - public static void main(String[] args) throws IOException, - ClassNotFoundException, InstantiationException, IllegalAccessException, SecurityException, IllegalArgumentException, NoSuchMethodException, InvocationTargetException { + public static void main(String[] args) throws IOException, ClassNotFoundException, 
InstantiationException, + IllegalAccessException, SecurityException, IllegalArgumentException, NoSuchMethodException, InvocationTargetException { DefaultOptionBuilder obuilder = new DefaultOptionBuilder(); ArgumentBuilder abuilder = new ArgumentBuilder(); GroupBuilder gbuilder = new GroupBuilder(); @@ -56,16 +56,25 @@ Option topicsOpt = DefaultOptionCreator.kOption().create(); Option helpOpt = DefaultOptionCreator.helpOption(); - Option mOpt = obuilder.withLongName("alpha").withRequired(true).withShortName("m"). - withArgument(abuilder.withName("alpha").withMinimum(1).withMaximum(1).create()). - withDescription("The alpha0 value for the DirichletDistribution.").create(); - - Option modelOpt = obuilder.withLongName("modelClass").withRequired(true).withShortName("d"). - withArgument(abuilder.withName("modelClass").withMinimum(1).withMaximum(1).create()). - withDescription("The ModelDistribution class name.").create(); - - Group group = gbuilder.withName("Options").withOption(inputOpt).withOption(outputOpt).withOption(modelOpt). 
- withOption(maxIterOpt).withOption(mOpt).withOption(topicsOpt).withOption(helpOpt).create(); + Option mOpt = obuilder.withLongName("alpha").withRequired(true).withShortName("m").withArgument( + abuilder.withName("alpha").withMinimum(1).withMaximum(1).create()).withDescription( + "The alpha0 value for the DirichletDistribution.").create(); + + Option modelOpt = obuilder.withLongName("modelClass").withRequired(true).withShortName("d").withArgument( + abuilder.withName("modelClass").withMinimum(1).withMaximum(1).create()) + .withDescription("The ModelDistribution class name.").create(); + + Option prototypeOpt = obuilder.withLongName("modelPrototypeClass").withRequired(true).withShortName("p").withArgument( + abuilder.withName("prototypeClass").withMinimum(1).withMaximum(1).create()).withDescription( + "The ModelDistribution prototype Vector class name.").create(); + + Option sizeOpt = obuilder.withLongName("prototypeSize").withRequired(true).withShortName("s").withArgument( + abuilder.withName("prototypeSize").withMinimum(1).withMaximum(1).create()).withDescription( + "The ModelDistribution prototype Vector size.").create(); + + Group group = gbuilder.withName("Options").withOption(inputOpt).withOption(outputOpt).withOption(modelOpt).withOption( + prototypeOpt).withOption(sizeOpt).withOption(maxIterOpt).withOption(mOpt).withOption(topicsOpt).withOption(helpOpt) + .create(); try { Parser parser = new Parser(); @@ -79,10 +88,12 @@ String input = cmdLine.getValue(inputOpt).toString(); String output = cmdLine.getValue(outputOpt).toString(); String modelFactory = cmdLine.getValue(modelOpt).toString(); + String modelPrototype = cmdLine.getValue(prototypeOpt).toString(); + int prototypeSize = Integer.parseInt(cmdLine.getValue(sizeOpt).toString()); int numModels = Integer.parseInt(cmdLine.getValue(topicsOpt).toString()); int maxIterations = Integer.parseInt(cmdLine.getValue(maxIterOpt).toString()); double alpha_0 = Double.parseDouble(cmdLine.getValue(mOpt).toString()); - 
runJob(input, output, modelFactory, numModels, maxIterations, alpha_0); + runJob(input, output, modelFactory, modelPrototype, prototypeSize, numModels, maxIterations, alpha_0); } catch (OptionException e) { log.error("Exception parsing command line: ", e); CommandLineUtil.printHelp(group); @@ -96,6 +107,8 @@ * @param input the directory pathname for input points * @param output the directory pathname for output points * @param modelFactory the ModelDistribution class name + * @param modelPrototype the Vector class name used by the modelFactory + * @param prototypeSize the size of the prototype vector * @param numModels the number of Models * @param maxIterations the maximum number of iterations * @param alpha_0 the alpha0 value for the DirichletDistribution @@ -104,9 +117,8 @@ * @throws IllegalArgumentException * @throws SecurityException */ - public static void runJob(String input, String output, String modelFactory, - int numModels, int maxIterations, double alpha_0) - throws IOException, ClassNotFoundException, InstantiationException, + public static void runJob(String input, String output, String modelFactory, String modelPrototype, int prototypeSize, + int numModels, int maxIterations, double alpha_0) throws IOException, ClassNotFoundException, InstantiationException, IllegalAccessException, SecurityException, IllegalArgumentException, NoSuchMethodException, InvocationTargetException { // delete the output directory Configuration conf = new JobConf(DirichletJob.class); @@ -116,7 +128,6 @@ fs.delete(outPath, true); } fs.mkdirs(outPath); - DirichletDriver.runJob(input, output, modelFactory, numModels, maxIterations, - alpha_0, 1); + DirichletDriver.runJob(input, output, modelFactory, modelPrototype, prototypeSize, numModels, maxIterations, alpha_0, 1); } } Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/AsymmetricSampledNormalModel.java URL: 
http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/AsymmetricSampledNormalModel.java?rev=900519&r1=900518&r2=900519&view=diff ============================================================================== --- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/AsymmetricSampledNormalModel.java (original) +++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/AsymmetricSampledNormalModel.java Mon Jan 18 19:23:25 2010 @@ -93,7 +93,7 @@ return; } mean = s1.divide(s0); - // compute the two component stds + // compute the component stds if (s0 > 1) { stdDev = s2.times(s0).minus(s1.times(s1)) .assign(new SquareRootFunction()).divide(s0); Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/NormalModel.java URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/NormalModel.java?rev=900519&r1=900518&r2=900519&view=diff ============================================================================== --- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/NormalModel.java (original) +++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/models/NormalModel.java Mon Jan 18 19:23:25 2010 @@ -100,7 +100,7 @@ if (s0 > 1) { Vector std = s2.times(s0).minus(s1.times(s1)).assign( new SquareRootFunction()).divide(s0); - stdDev = std.zSum() / s1.size(); + stdDev = std.zSum() / std.size(); } else { stdDev = Double.MIN_VALUE; } Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/Cluster.java URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/Cluster.java?rev=900519&r1=900518&r2=900519&view=diff ============================================================================== --- 
lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/Cluster.java (original) +++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/kmeans/Cluster.java Mon Jan 18 19:23:25 2010 @@ -244,7 +244,7 @@ Vector stds = pointSquaredTotal.times(getNumPoints()).minus( getPointTotal().times(getPointTotal())) .assign(new SquareRootFunction()).divide(getNumPoints()); - return stds.zSum() / 2; + return stds.zSum() / stds.size(); } } Modified: lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/dirichlet/Display2dASNDirichlet.java URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/dirichlet/Display2dASNDirichlet.java?rev=900519&r1=900518&r2=900519&view=diff ============================================================================== --- lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/dirichlet/Display2dASNDirichlet.java (original) +++ lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/dirichlet/Display2dASNDirichlet.java Mon Jan 18 19:23:25 2010 @@ -32,8 +32,8 @@ class Display2dASNDirichlet extends DisplayDirichlet { Display2dASNDirichlet() { initialize(); - this.setTitle("Dirichlet Process Clusters - 2-d Asymmetric Sampled Normal Distribution (>" - + (int) (significance * 100) + "% of population)"); + this.setTitle("Dirichlet Process Clusters - 2-d Asymmetric Sampled Normal Distribution (>" + (int) (significance * 100) + + "% of population)"); } @Override @@ -63,6 +63,6 @@ } private static void generateResults() { - DisplayDirichlet.generateResults(new AsymmetricSampledNormalDistribution()); + DisplayDirichlet.generateResults(new AsymmetricSampledNormalDistribution(new VectorWritable(new DenseVector(2)))); } } Modified: lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/dirichlet/DisplayASNDirichlet.java URL: 
http://svn.apache.org/viewvc/lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/dirichlet/DisplayASNDirichlet.java?rev=900519&r1=900518&r2=900519&view=diff ============================================================================== --- lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/dirichlet/DisplayASNDirichlet.java (original) +++ lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/dirichlet/DisplayASNDirichlet.java Mon Jan 18 19:23:25 2010 @@ -64,6 +64,6 @@ } static void generateResults() { - DisplayDirichlet.generateResults(new AsymmetricSampledNormalDistribution()); + DisplayDirichlet.generateResults(new AsymmetricSampledNormalDistribution(new VectorWritable(new DenseVector(2)))); } } Modified: lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/dirichlet/DisplayASNOutputState.java URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/dirichlet/DisplayASNOutputState.java?rev=900519&r1=900518&r2=900519&view=diff ============================================================================== --- lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/dirichlet/DisplayASNOutputState.java (original) +++ lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/dirichlet/DisplayASNOutputState.java Mon Jan 18 19:23:25 2010 @@ -97,8 +97,10 @@ conf .set(DirichletDriver.MODEL_FACTORY_KEY, "org.apache.mahout.clustering.dirichlet.models.AsymmetricSampledNormalDistribution"); - conf.set(DirichletDriver.NUM_CLUSTERS_KEY, Integer.toString(20)); - conf.set(DirichletDriver.ALPHA_0_KEY, Double.toString(1.0)); + conf.set(DirichletDriver.MODEL_PROTOTYPE_KEY, "org.apache.mahout.math.DenseVector"); + conf.set(DirichletDriver.PROTOTYPE_SIZE_KEY, "2"); + conf.set(DirichletDriver.NUM_CLUSTERS_KEY, "20"); + conf.set(DirichletDriver.ALPHA_0_KEY, "1.0"); File f = new File("output"); for (File g : f.listFiles()) { 
conf.set(DirichletDriver.STATE_IN_KEY, g.getCanonicalPath()); @@ -115,7 +117,7 @@ } static void generateResults() { - DisplayDirichlet.generateResults(new NormalModelDistribution()); + DisplayDirichlet.generateResults(new NormalModelDistribution(new VectorWritable(new DenseVector(2)))); } } Modified: lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/dirichlet/DisplayDirichlet.java URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/dirichlet/DisplayDirichlet.java?rev=900519&r1=900518&r2=900519&view=diff ============================================================================== --- lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/dirichlet/DisplayDirichlet.java (original) +++ lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/dirichlet/DisplayDirichlet.java Mon Jan 18 19:23:25 2010 @@ -140,10 +140,10 @@ } /** - * Plot the points on the graphics context + * Draw a rectangle on the graphics context * @param g2 a Graphics2D context - * @param v a Vector of rectangle centers - * @param dv a Vector of rectangle sizes + * @param v a Vector of rectangle center + * @param dv a Vector of rectangle dimensions */ public static void plotRectangle(Graphics2D g2, Vector v, Vector dv) { double[] flip = { 1, -1 }; @@ -157,10 +157,10 @@ } /** - * Plot the points on the graphics context + * Draw an ellipse on the graphics context * @param g2 a Graphics2D context - * @param v a Vector of rectangle centers - * @param dv a Vector of rectangle sizes + * @param v a Vector of ellipse center + * @param dv a Vector of ellipse dimensions */ public static void plotEllipse(Graphics2D g2, Vector v, Vector dv) { double[] flip = { 1, -1 }; Modified: lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/dirichlet/DisplayNDirichlet.java URL: 
http://svn.apache.org/viewvc/lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/dirichlet/DisplayNDirichlet.java?rev=900519&r1=900518&r2=900519&view=diff ============================================================================== --- lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/dirichlet/DisplayNDirichlet.java (original) +++ lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/dirichlet/DisplayNDirichlet.java Mon Jan 18 19:23:25 2010 @@ -63,6 +63,6 @@ } static void generateResults() { - DisplayDirichlet.generateResults(new NormalModelDistribution()); + DisplayDirichlet.generateResults(new NormalModelDistribution(new VectorWritable(new DenseVector(2)))); } } Modified: lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/dirichlet/DisplayOutputState.java URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/dirichlet/DisplayOutputState.java?rev=900519&r1=900518&r2=900519&view=diff ============================================================================== --- lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/dirichlet/DisplayOutputState.java (original) +++ lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/dirichlet/DisplayOutputState.java Mon Jan 18 19:23:25 2010 @@ -92,8 +92,10 @@ JobConf conf = new JobConf(KMeansDriver.class); conf.set(DirichletDriver.MODEL_FACTORY_KEY, "org.apache.mahout.clustering.dirichlet.models.SampledNormalDistribution"); - conf.set(DirichletDriver.NUM_CLUSTERS_KEY, Integer.toString(20)); - conf.set(DirichletDriver.ALPHA_0_KEY, Double.toString(1.0)); + conf.set(DirichletDriver.MODEL_PROTOTYPE_KEY, "org.apache.mahout.math.DenseVector"); + conf.set(DirichletDriver.PROTOTYPE_SIZE_KEY, "2"); + conf.set(DirichletDriver.NUM_CLUSTERS_KEY, "20"); + conf.set(DirichletDriver.ALPHA_0_KEY, "1.0"); File f = new File("output"); for (File g : f.listFiles()) { 
conf.set(DirichletDriver.STATE_IN_KEY, g.getCanonicalPath()); @@ -110,6 +112,6 @@ } static void generateResults() { - DisplayDirichlet.generateResults(new NormalModelDistribution()); + DisplayDirichlet.generateResults(new NormalModelDistribution(new VectorWritable(new DenseVector(2)))); } } Modified: lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/dirichlet/DisplaySNDirichlet.java URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/dirichlet/DisplaySNDirichlet.java?rev=900519&r1=900518&r2=900519&view=diff ============================================================================== --- lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/dirichlet/DisplaySNDirichlet.java (original) +++ lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/clustering/dirichlet/DisplaySNDirichlet.java Mon Jan 18 19:23:25 2010 @@ -63,6 +63,6 @@ } static void generateResults() { - DisplayDirichlet.generateResults(new SampledNormalDistribution()); + DisplayDirichlet.generateResults(new SampledNormalDistribution(new VectorWritable(new DenseVector(2)))); } } Modified: lucene/mahout/trunk/math/src/main/java/org/apache/mahout/math/SquareRootFunction.java URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/math/src/main/java/org/apache/mahout/math/SquareRootFunction.java?rev=900519&r1=900518&r2=900519&view=diff ============================================================================== --- lucene/mahout/trunk/math/src/main/java/org/apache/mahout/math/SquareRootFunction.java (original) +++ lucene/mahout/trunk/math/src/main/java/org/apache/mahout/math/SquareRootFunction.java Mon Jan 18 19:23:25 2010 @@ -21,7 +21,7 @@ @Override public double apply(double arg1) { - return Math.abs(arg1); + return Math.sqrt(arg1); } }