Author: jeastman Date: Mon Feb 8 23:11:51 2010 New Revision: 907842 URL: http://svn.apache.org/viewvc?rev=907842&view=rev Log: MAHOUT-276
- added alpha_0 parameter to rDirichlet and incorporated into rBeta arguments - passed alpha_0 argument in DirichletMapper and DirichletState calls to rDirichlet - removed totalCount = alpha_0/k initialization in DirichletState. All tests still run and seem to produce reasonable outputs. Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletMapper.java lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletState.java lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/UncommonDistributions.java Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletMapper.java URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletMapper.java?rev=907842&r1=907841&r2=907842&view=diff ============================================================================== --- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletMapper.java (original) +++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletMapper.java Mon Feb 8 23:11:51 2010 @@ -86,9 +86,10 @@ String alpha_0 = job.get(DirichletDriver.ALPHA_0_KEY); try { + double alpha = Double.parseDouble(alpha_0); DirichletState<VectorWritable> state = DirichletDriver.createState( modelFactory, modelPrototype, Integer.parseInt(prototypeSize), - Integer.parseInt(numClusters), Double.parseDouble(alpha_0)); + Integer.parseInt(numClusters), alpha); Path path = new Path(statePath); FileSystem fs = FileSystem.get(path.toUri(), job); FileStatus[] status = fs.listStatus(path, new OutputLogFilter()); @@ -108,7 +109,7 @@ } } // TODO: with more than one mapper, they will all have different mixtures. Will this matter? 
- state.setMixture(UncommonDistributions.rDirichlet(state.totalCounts())); + state.setMixture(UncommonDistributions.rDirichlet(state.totalCounts(), alpha)); return state; } catch (ClassNotFoundException e) { throw new IllegalStateException(e); Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletState.java URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletState.java?rev=907842&r1=907841&r2=907842&view=diff ============================================================================== --- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletState.java (original) +++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/DirichletState.java Mon Feb 8 23:11:51 2010 @@ -35,21 +35,20 @@ private Vector mixture; // the mixture vector - private double offset; // alpha_0 / numClusters + private double alpha_0; // alpha_0 public DirichletState(ModelDistribution<O> modelFactory, int numClusters, double alpha_0, int thin, int burnin) { this.numClusters = numClusters; this.modelFactory = modelFactory; - // initialize totalCounts - offset = alpha_0 / numClusters; + this.alpha_0 = alpha_0; // sample initial prior models clusters = new ArrayList<DirichletCluster<O>>(); for (Model<O> m : modelFactory.sampleFromPrior(numClusters)) { - clusters.add(new DirichletCluster<O>(m, offset)); + clusters.add(new DirichletCluster<O>(m, 0.0)); } // sample the mixture parameters from a Dirichlet distribution on the totalCounts - mixture = UncommonDistributions.rDirichlet(totalCounts()); + mixture = UncommonDistributions.rDirichlet(totalCounts(), alpha_0); } public DirichletState() { @@ -87,14 +86,6 @@ this.mixture = mixture; } - public double getOffset() { - return offset; - } - - public void setOffset(double offset) { - this.offset = offset; - } - public Vector totalCounts() { Vector result = new DenseVector(numClusters); for 
(int i = 0; i < numClusters; i++) { @@ -115,7 +106,7 @@ clusters.get(i).setModel(newModels[i]); } // update the mixture - mixture = UncommonDistributions.rDirichlet(totalCounts()); + mixture = UncommonDistributions.rDirichlet(totalCounts(), alpha_0); } /** @@ -131,6 +122,7 @@ return mix * pdf; } + @SuppressWarnings("unchecked") public Model<O>[] getModels() { Model<O>[] result = (Model<O>[]) new Model[numClusters]; for (int i = 0; i < numClusters; i++) { Modified: lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/UncommonDistributions.java URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/UncommonDistributions.java?rev=907842&r1=907841&r2=907842&view=diff ============================================================================== --- lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/UncommonDistributions.java (original) +++ lucene/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/dirichlet/UncommonDistributions.java Mon Feb 8 23:11:51 2010 @@ -228,25 +228,26 @@ } /** - * Sample from a Dirichlet distribution over the given alpha, returning a vector of probabilities using a + * Sample from a Dirichlet distribution, returning a vector of probabilities using a * stick-breaking algorithm * - * @param alpha an unnormalized count Vector + * @param totalCounts an unnormalized count Vector + * @param alpha_0 a double * @return a Vector of probabilities */ - public static Vector rDirichlet(Vector alpha) { - Vector r = alpha.like(); - double total = alpha.zSum(); - double remainder = 1; - for (int i = 0; i < r.size(); i++) { - double a = alpha.get(i); - total -= a; - double beta = rBeta(a, Math.max(0, total)); - double p = beta * remainder; - r.set(i, p); - remainder -= p; + public static Vector rDirichlet(Vector totalCounts, double alpha_0) { + Vector pi = totalCounts.like(); + double total = totalCounts.zSum(); + double remainder = 1.0; + for 
(int k = 0; k < pi.size(); k++) { + double countK = totalCounts.get(k); + total -= countK; + double betaK = rBeta(1 + countK, Math.max(0, alpha_0 + total)); + double piK = betaK * remainder; + pi.set(k, piK); + remainder -= piK; } - return r; + return pi; } }