http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/iterator/ClusterIterator.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/iterator/ClusterIterator.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/iterator/ClusterIterator.java new file mode 100644 index 0000000..516177f --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/iterator/ClusterIterator.java @@ -0,0 +1,219 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.mahout.clustering.iterator; + +import java.io.IOException; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FileStatus; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.mapreduce.Job; +import org.apache.hadoop.mapreduce.lib.input.FileInputFormat; +import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat; +import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat; +import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat; +import org.apache.mahout.clustering.Cluster; +import org.apache.mahout.clustering.classify.ClusterClassifier; +import org.apache.mahout.common.iterator.sequencefile.PathFilters; +import org.apache.mahout.common.iterator.sequencefile.PathType; +import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirValueIterable; +import org.apache.mahout.common.iterator.sequencefile.SequenceFileValueIterator; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.VectorWritable; + +import com.google.common.io.Closeables; + +/** + * This is a clustering iterator which works with a set of Vector data and a prior ClusterClassifier which has been + * initialized with a set of models. Its implementation is algorithm-neutral and works for any iterative clustering + * algorithm (currently k-means and fuzzy-k-means) that processes all the input vectors in each iteration. + * The cluster classifier is configured with a ClusteringPolicy to select the desired clustering algorithm. 
+ */ +public final class ClusterIterator { + + public static final String PRIOR_PATH_KEY = "org.apache.mahout.clustering.prior.path"; + + private ClusterIterator() { + } + + /** + * Iterate over data using a prior-trained ClusterClassifier for a given number of iterations + * + * @param data + * an {@code Iterable<Vector>} of input vectors + * @param classifier + * a prior ClusterClassifier + * @param numIterations + * the int number of iterations to perform + * + * @return the posterior ClusterClassifier + */ + public static ClusterClassifier iterate(Iterable<Vector> data, ClusterClassifier classifier, int numIterations) { + ClusteringPolicy policy = classifier.getPolicy(); + for (int iteration = 1; iteration <= numIterations; iteration++) { + for (Vector vector : data) { + // update the policy based upon the prior + policy.update(classifier); + // classification yields probabilities + Vector probabilities = classifier.classify(vector); + // policy selects weights for models given those probabilities + Vector weights = policy.select(probabilities); + // training causes all models to observe data + for (Vector.Element e : weights.nonZeroes()) { + int index = e.index(); + classifier.train(index, vector, weights.get(index)); + } + } + // compute the posterior models + classifier.close(); + } + return classifier; + } + + /** + * Iterate over data using a prior-trained ClusterClassifier for a given number of iterations, using a sequential + * implementation + * + * @param conf + * the Configuration + * @param inPath + * a Path to input VectorWritables + * @param priorPath + * a Path to the prior classifier + * @param outPath + * a Path to the output directory + * @param numIterations + * the int number of iterations to perform + */ + public static void iterateSeq(Configuration conf, Path inPath, Path priorPath, Path outPath, int numIterations) + throws IOException { + ClusterClassifier classifier = new ClusterClassifier(); + classifier.readFromSeqFiles(conf, priorPath); + Path clustersOut = null; + int iteration = 1; + while (iteration <= numIterations) { + for (VectorWritable vw : new SequenceFileDirValueIterable<VectorWritable>(inPath, PathType.LIST, + PathFilters.logsCRCFilter(), conf)) { + Vector vector = vw.get(); + // classification yields probabilities + Vector probabilities = classifier.classify(vector); + // policy selects weights for models given those probabilities + Vector weights = classifier.getPolicy().select(probabilities); + // training causes all models to observe data + for (Vector.Element e : weights.nonZeroes()) { + int index = e.index(); + classifier.train(index, vector, weights.get(index)); + } + } + // compute the posterior models + classifier.close(); + // update the policy + classifier.getPolicy().update(classifier); + // output the classifier + clustersOut = new Path(outPath, Cluster.CLUSTERS_DIR + iteration); + classifier.writeToSeqFiles(clustersOut); + FileSystem fs = FileSystem.get(outPath.toUri(), conf); + iteration++; + if (isConverged(clustersOut, conf, fs)) { + break; + } + } + Path finalClustersIn = new Path(outPath, Cluster.CLUSTERS_DIR + (iteration - 1) + Cluster.FINAL_ITERATION_SUFFIX); + FileSystem.get(clustersOut.toUri(), conf).rename(clustersOut, finalClustersIn); + } + + /** + * Iterate over data using a prior-trained ClusterClassifier for a given number of iterations, using a MapReduce + * implementation + * + * @param conf + * the Configuration + * @param inPath + * a Path to input VectorWritables + * @param priorPath + * a Path to the prior classifier + * @param outPath + * a Path to the output directory + * @param numIterations + * the int number of iterations to perform + */ + public static void iterateMR(Configuration conf, Path inPath, Path priorPath, Path outPath, int numIterations) + throws IOException, InterruptedException, ClassNotFoundException { + ClusteringPolicy policy = ClusterClassifier.readPolicy(priorPath); + Path clustersOut = null; + int iteration = 1; + while (iteration <= numIterations) { + conf.set(PRIOR_PATH_KEY, priorPath.toString()); + + String jobName = "Cluster Iterator running iteration " + iteration + " over priorPath: " + priorPath; + Job job = new Job(conf, jobName); + job.setMapOutputKeyClass(IntWritable.class); + job.setMapOutputValueClass(ClusterWritable.class); + job.setOutputKeyClass(IntWritable.class); + job.setOutputValueClass(ClusterWritable.class); + + job.setInputFormatClass(SequenceFileInputFormat.class); + job.setOutputFormatClass(SequenceFileOutputFormat.class); + job.setMapperClass(CIMapper.class); + job.setReducerClass(CIReducer.class); + + FileInputFormat.addInputPath(job, inPath); + clustersOut = new Path(outPath, Cluster.CLUSTERS_DIR + iteration); + priorPath = clustersOut; + FileOutputFormat.setOutputPath(job, clustersOut); + + job.setJarByClass(ClusterIterator.class); + if (!job.waitForCompletion(true)) { + throw new InterruptedException("Cluster Iteration " + iteration + " failed processing " + priorPath); + } + ClusterClassifier.writePolicy(policy, clustersOut); + FileSystem fs = FileSystem.get(outPath.toUri(), conf); + iteration++; + if (isConverged(clustersOut, conf, fs)) { + break; + } + } + Path finalClustersIn = new Path(outPath, Cluster.CLUSTERS_DIR + (iteration - 1) + Cluster.FINAL_ITERATION_SUFFIX); + FileSystem.get(clustersOut.toUri(), conf).rename(clustersOut, finalClustersIn); + } + + /** + * Return whether all of the Clusters in the part files under filePath have converged + * + * @param filePath + * the path to the directory containing the cluster part files + * @return true if all Clusters are converged + * @throws IOException + * if there was an IO error + */ + private static boolean isConverged(Path filePath, Configuration conf, FileSystem fs) throws IOException { + for (FileStatus part : fs.listStatus(filePath, PathFilters.partFilter())) { + SequenceFileValueIterator<ClusterWritable> iterator = new SequenceFileValueIterator<>( + part.getPath(), true, conf); + while (iterator.hasNext()) { + ClusterWritable value = iterator.next(); + if (!value.getValue().isConverged()) { + Closeables.close(iterator, true); + return false; + } + } + } + return true; + } +}
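The in-memory iterate() above is the simplest way to exercise this machinery: it needs only an Iterable of Vectors and a prior ClusterClassifier. Below is a minimal, hypothetical sketch of driving it for k-means; the class name and toy data are illustrative, and it assumes the classes added in this patch plus DenseVector and SquaredEuclideanDistanceMeasure from the Mahout math and common modules are on the classpath.

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import org.apache.mahout.clustering.Cluster;
import org.apache.mahout.clustering.classify.ClusterClassifier;
import org.apache.mahout.clustering.iterator.ClusterIterator;
import org.apache.mahout.clustering.iterator.KMeansClusteringPolicy;
import org.apache.mahout.clustering.kmeans.Kluster;
import org.apache.mahout.common.distance.DistanceMeasure;
import org.apache.mahout.common.distance.SquaredEuclideanDistanceMeasure;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;

public class InMemoryKMeansSketch {
  public static void main(String[] args) {
    DistanceMeasure measure = new SquaredEuclideanDistanceMeasure();

    // two well-separated groups of toy 2-d points
    List<Vector> data = Arrays.<Vector>asList(
        new DenseVector(new double[] {0.0, 0.1}),
        new DenseVector(new double[] {0.1, 0.0}),
        new DenseVector(new double[] {9.0, 9.1}),
        new DenseVector(new double[] {9.1, 9.0}));

    // seed one prior Kluster per desired cluster from the data
    List<Cluster> priors = new ArrayList<>();
    priors.add(new Kluster(data.get(0), 0, measure));
    priors.add(new Kluster(data.get(2), 1, measure));

    // KMeansClusteringPolicy gives the classifier hard k-means assignment semantics
    ClusterClassifier prior = new ClusterClassifier(priors, new KMeansClusteringPolicy(0.001));

    // run up to 10 in-memory iterations; the returned classifier holds the posterior models
    ClusterClassifier posterior = ClusterIterator.iterate(data, prior, 10);
    for (Cluster model : posterior.getModels()) {
      System.out.println(model.getId() + ": " + model.getCenter());
    }
  }
}

iterateSeq() and iterateMR() run the same classify/select/train/close loop, but read VectorWritables from SequenceFiles, write one clusters-&lt;n&gt; directory per iteration under the output path, and stop early once isConverged() reports that every model has converged.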
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/iterator/ClusterWritable.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/iterator/ClusterWritable.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/iterator/ClusterWritable.java new file mode 100644 index 0000000..855685f --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/iterator/ClusterWritable.java @@ -0,0 +1,56 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.mahout.clustering.iterator; + +import java.io.DataInput; +import java.io.DataOutput; +import java.io.IOException; + +import org.apache.hadoop.io.Writable; +import org.apache.mahout.classifier.sgd.PolymorphicWritable; +import org.apache.mahout.clustering.Cluster; + +public class ClusterWritable implements Writable { + + private Cluster value; + + public ClusterWritable(Cluster first) { + value = first; + } + + public ClusterWritable() { + } + + public Cluster getValue() { + return value; + } + + public void setValue(Cluster value) { + this.value = value; + } + + @Override + public void write(DataOutput out) throws IOException { + PolymorphicWritable.write(out, value); + } + + @Override + public void readFields(DataInput in) throws IOException { + value = PolymorphicWritable.read(in, Cluster.class); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/iterator/ClusteringPolicy.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/iterator/ClusteringPolicy.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/iterator/ClusteringPolicy.java new file mode 100644 index 0000000..6e15838 --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/iterator/ClusteringPolicy.java @@ -0,0 +1,66 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.mahout.clustering.iterator; + +import org.apache.hadoop.io.Writable; +import org.apache.mahout.clustering.classify.ClusterClassifier; +import org.apache.mahout.math.Vector; + +/** + * A ClusteringPolicy captures the semantics of assignment of points to clusters + * + */ +public interface ClusteringPolicy extends Writable { + + /** + * Classify the data vector given the classifier's models + * + * @param data + * a data Vector + * @param prior + * a prior ClusterClassifier + * @return a Vector of probabilities that the data is described by each of the + * models + */ + Vector classify(Vector data, ClusterClassifier prior); + + /** + * Return a vector of weights for each of the models given those probabilities + * + * @param probabilities + * a Vector of pdfs + * @return a Vector of weights + */ + Vector select(Vector probabilities); + + /** + * Update the policy with the given classifier + * + * @param posterior + * a ClusterClassifier + */ + void update(ClusterClassifier posterior); + + /** + * Close the policy using the classifier's models + * + * @param posterior + * a posterior ClusterClassifier + */ + void close(ClusterClassifier posterior); + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/iterator/ClusteringPolicyWritable.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/iterator/ClusteringPolicyWritable.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/iterator/ClusteringPolicyWritable.java new file mode 100644 index 0000000..f69442d --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/iterator/ClusteringPolicyWritable.java @@ -0,0 +1,55 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.mahout.clustering.iterator; + +import java.io.DataInput; +import java.io.DataOutput; +import java.io.IOException; + +import org.apache.hadoop.io.Writable; +import org.apache.mahout.classifier.sgd.PolymorphicWritable; + +public class ClusteringPolicyWritable implements Writable { + + private ClusteringPolicy value; + + public ClusteringPolicyWritable(ClusteringPolicy policy) { + this.value = policy; + } + + public ClusteringPolicyWritable() { + } + + public ClusteringPolicy getValue() { + return value; + } + + public void setValue(ClusteringPolicy value) { + this.value = value; + } + + @Override + public void write(DataOutput out) throws IOException { + PolymorphicWritable.write(out, value); + } + + @Override + public void readFields(DataInput in) throws IOException { + value = PolymorphicWritable.read(in, ClusteringPolicy.class); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/iterator/DistanceMeasureCluster.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/iterator/DistanceMeasureCluster.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/iterator/DistanceMeasureCluster.java new file mode 100644 index 0000000..f61aa27 --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/iterator/DistanceMeasureCluster.java @@ -0,0 +1,91 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.clustering.iterator; + +import java.io.DataInput; +import java.io.DataOutput; +import java.io.IOException; + +import org.apache.hadoop.conf.Configuration; +import org.apache.mahout.clustering.AbstractCluster; +import org.apache.mahout.clustering.Model; +import org.apache.mahout.common.ClassUtils; +import org.apache.mahout.common.distance.DistanceMeasure; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.VectorWritable; + +public class DistanceMeasureCluster extends AbstractCluster { + + private DistanceMeasure measure; + + public DistanceMeasureCluster(Vector point, int id, DistanceMeasure measure) { + super(point, id); + this.measure = measure; + } + + public DistanceMeasureCluster() { + } + + @Override + public void configure(Configuration job) { + if (measure != null) { + measure.configure(job); + } + } + + @Override + public void readFields(DataInput in) throws IOException { + String dm = in.readUTF(); + this.measure = ClassUtils.instantiateAs(dm, DistanceMeasure.class); + super.readFields(in); + } + + @Override + public void write(DataOutput out) throws IOException { + out.writeUTF(measure.getClass().getName()); + super.write(out); + } + + @Override + public double pdf(VectorWritable vw) { + return 1 / (1 + measure.distance(vw.get(), getCenter())); + } + + @Override + public Model<VectorWritable> sampleFromPosterior() { + return new DistanceMeasureCluster(getCenter(), getId(), measure); + } + + public DistanceMeasure getMeasure() { + return measure; + } + + /** + * @param measure + * the measure to set + */ + public void setMeasure(DistanceMeasure measure) { + this.measure = measure; + } + + @Override + public String getIdentifier() { + return "DMC:" + getId(); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/iterator/FuzzyKMeansClusteringPolicy.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/iterator/FuzzyKMeansClusteringPolicy.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/iterator/FuzzyKMeansClusteringPolicy.java new file mode 100644 index 0000000..b4e41b6 --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/iterator/FuzzyKMeansClusteringPolicy.java @@ -0,0 +1,90 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.mahout.clustering.iterator; + +import java.io.DataInput; +import java.io.DataOutput; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; + +import org.apache.mahout.clustering.Cluster; +import org.apache.mahout.clustering.classify.ClusterClassifier; +import org.apache.mahout.clustering.fuzzykmeans.FuzzyKMeansClusterer; +import org.apache.mahout.clustering.fuzzykmeans.SoftCluster; +import org.apache.mahout.math.Vector; + +/** + * This is a probability-weighted clustering policy, suitable for fuzzy k-means + * clustering + * + */ +public class FuzzyKMeansClusteringPolicy extends AbstractClusteringPolicy { + + private double m = 2; + private double convergenceDelta = 0.05; + + public FuzzyKMeansClusteringPolicy() { + } + + public FuzzyKMeansClusteringPolicy(double m, double convergenceDelta) { + this.m = m; + this.convergenceDelta = convergenceDelta; + } + + @Override + public Vector select(Vector probabilities) { + return probabilities; + } + + @Override + public Vector classify(Vector data, ClusterClassifier prior) { + Collection<SoftCluster> clusters = new ArrayList<>(); + List<Double> distances = new ArrayList<>(); + for (Cluster model : prior.getModels()) { + SoftCluster sc = (SoftCluster) model; + clusters.add(sc); + distances.add(sc.getMeasure().distance(data, sc.getCenter())); + } + FuzzyKMeansClusterer fuzzyKMeansClusterer = new FuzzyKMeansClusterer(); + fuzzyKMeansClusterer.setM(m); + return fuzzyKMeansClusterer.computePi(clusters, distances); + } + + @Override + public void write(DataOutput out) throws IOException { + out.writeDouble(m); + out.writeDouble(convergenceDelta); + } + + @Override + public void readFields(DataInput in) throws IOException { + this.m = in.readDouble(); + this.convergenceDelta = in.readDouble(); + } + + @Override + public void close(ClusterClassifier posterior) { + for (Cluster cluster : posterior.getModels()) { + ((org.apache.mahout.clustering.kmeans.Kluster) cluster).calculateConvergence(convergenceDelta); + cluster.computeParameters(); + } + + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/iterator/KMeansClusteringPolicy.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/iterator/KMeansClusteringPolicy.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/iterator/KMeansClusteringPolicy.java new file mode 100644 index 0000000..1cc9faf --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/iterator/KMeansClusteringPolicy.java @@ -0,0 +1,64 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.mahout.clustering.iterator; + +import java.io.DataInput; +import java.io.DataOutput; +import java.io.IOException; + +import org.apache.mahout.clustering.Cluster; +import org.apache.mahout.clustering.classify.ClusterClassifier; + +/** + * This is a simple maximum likelihood clustering policy, suitable for k-means + * clustering + * + */ +public class KMeansClusteringPolicy extends AbstractClusteringPolicy { + + public KMeansClusteringPolicy() { + } + + public KMeansClusteringPolicy(double convergenceDelta) { + this.convergenceDelta = convergenceDelta; + } + + private double convergenceDelta = 0.001; + + @Override + public void write(DataOutput out) throws IOException { + out.writeDouble(convergenceDelta); + } + + @Override + public void readFields(DataInput in) throws IOException { + this.convergenceDelta = in.readDouble(); + } + + @Override + public void close(ClusterClassifier posterior) { + boolean allConverged = true; + for (Cluster cluster : posterior.getModels()) { + org.apache.mahout.clustering.kmeans.Kluster kluster = (org.apache.mahout.clustering.kmeans.Kluster) cluster; + boolean converged = kluster.calculateConvergence(convergenceDelta); + allConverged = allConverged && converged; + cluster.computeParameters(); + } + + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/kernel/IKernelProfile.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/kernel/IKernelProfile.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/kernel/IKernelProfile.java new file mode 100644 index 0000000..96c4082 --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/kernel/IKernelProfile.java @@ -0,0 +1,27 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.clustering.kernel; + +public interface IKernelProfile { + + /** + * @return the calculated derivative value of the kernel + */ + double calculateDerivativeValue(double distance, double h); + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/kernel/TriangularKernelProfile.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/kernel/TriangularKernelProfile.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/kernel/TriangularKernelProfile.java new file mode 100644 index 0000000..46909bb --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/kernel/TriangularKernelProfile.java @@ -0,0 +1,27 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements.  See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.clustering.kernel; + +public class TriangularKernelProfile implements IKernelProfile { + + @Override + public double calculateDerivativeValue(double distance, double h) { + return distance < h ? 1.0 : 0.0; + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/kmeans/KMeansDriver.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/kmeans/KMeansDriver.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/kmeans/KMeansDriver.java new file mode 100644 index 0000000..3b9094e --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/kmeans/KMeansDriver.java @@ -0,0 +1,257 @@ +/* Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements.  See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership.  The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.  You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.mahout.clustering.kmeans; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.util.ToolRunner; +import org.apache.mahout.clustering.Cluster; +import org.apache.mahout.clustering.classify.ClusterClassificationDriver; +import org.apache.mahout.clustering.classify.ClusterClassifier; +import org.apache.mahout.clustering.iterator.ClusterIterator; +import org.apache.mahout.clustering.iterator.ClusteringPolicy; +import org.apache.mahout.clustering.iterator.KMeansClusteringPolicy; +import org.apache.mahout.clustering.topdown.PathDirectory; +import org.apache.mahout.common.AbstractJob; +import org.apache.mahout.common.ClassUtils; +import org.apache.mahout.common.HadoopUtil; +import org.apache.mahout.common.commandline.DefaultOptionCreator; +import org.apache.mahout.common.distance.DistanceMeasure; +import org.apache.mahout.common.distance.SquaredEuclideanDistanceMeasure; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class KMeansDriver extends AbstractJob { + + private static final Logger log = LoggerFactory.getLogger(KMeansDriver.class); + + public static void main(String[] args) throws Exception { + ToolRunner.run(new Configuration(), new KMeansDriver(), args); + } + + @Override + public int run(String[] args) throws Exception { + + addInputOption(); + addOutputOption(); + addOption(DefaultOptionCreator.distanceMeasureOption().create()); + addOption(DefaultOptionCreator + .clustersInOption() + .withDescription( + "The input centroids, as Vectors. Must be a SequenceFile of Writable, Cluster/Canopy. " + + "If k is also specified, then a random set of vectors will be selected" + + " and written out to this path first").create()); + addOption(DefaultOptionCreator + .numClustersOption() + .withDescription( + "The k in k-Means. 
If specified, then a random selection of k Vectors will be chosen" + " as the initial centroids and written to the clusters input path.").create()); + addOption(DefaultOptionCreator.useSetRandomSeedOption().create()); + addOption(DefaultOptionCreator.convergenceOption().create()); + addOption(DefaultOptionCreator.maxIterationsOption().create()); + addOption(DefaultOptionCreator.overwriteOption().create()); + addOption(DefaultOptionCreator.clusteringOption().create()); + addOption(DefaultOptionCreator.methodOption().create()); + addOption(DefaultOptionCreator.outlierThresholdOption().create()); + + if (parseArguments(args) == null) { + return -1; + } + + Path input = getInputPath(); + Path clusters = new Path(getOption(DefaultOptionCreator.CLUSTERS_IN_OPTION)); + Path output = getOutputPath(); + String measureClass = getOption(DefaultOptionCreator.DISTANCE_MEASURE_OPTION); + if (measureClass == null) { + measureClass = SquaredEuclideanDistanceMeasure.class.getName(); + } + double convergenceDelta = Double.parseDouble(getOption(DefaultOptionCreator.CONVERGENCE_DELTA_OPTION)); + int maxIterations = Integer.parseInt(getOption(DefaultOptionCreator.MAX_ITERATIONS_OPTION)); + if (hasOption(DefaultOptionCreator.OVERWRITE_OPTION)) { + HadoopUtil.delete(getConf(), output); + } + DistanceMeasure measure = ClassUtils.instantiateAs(measureClass, DistanceMeasure.class); + + if (hasOption(DefaultOptionCreator.NUM_CLUSTERS_OPTION)) { + int numClusters = Integer.parseInt(getOption(DefaultOptionCreator.NUM_CLUSTERS_OPTION)); + + Long seed = null; + if (hasOption(DefaultOptionCreator.RANDOM_SEED)) { + seed = Long.parseLong(getOption(DefaultOptionCreator.RANDOM_SEED)); + } + + clusters = RandomSeedGenerator.buildRandom(getConf(), input, clusters, numClusters, measure, seed); + } + boolean runClustering = hasOption(DefaultOptionCreator.CLUSTERING_OPTION); + boolean runSequential = getOption(DefaultOptionCreator.METHOD_OPTION).equalsIgnoreCase( + DefaultOptionCreator.SEQUENTIAL_METHOD); + double clusterClassificationThreshold = 0.0; + if (hasOption(DefaultOptionCreator.OUTLIER_THRESHOLD)) { + clusterClassificationThreshold = Double.parseDouble(getOption(DefaultOptionCreator.OUTLIER_THRESHOLD)); + } + run(getConf(), input, clusters, output, convergenceDelta, maxIterations, runClustering, + clusterClassificationThreshold, runSequential); + return 0; + } + + /** + * Iterate over the input vectors to produce clusters and, if requested, use the results of the final iteration to + * cluster the input vectors. + * + * @param input + * the directory pathname for input points + * @param clustersIn + * the directory pathname for initial & computed clusters + * @param output + * the directory pathname for output points + * @param convergenceDelta + * the convergence delta value + * @param maxIterations + * the maximum number of iterations + * @param runClustering + * true if points are to be clustered after iterations are completed + * @param clusterClassificationThreshold + * A clustering strictness / outlier removal parameter; its value should be between 0 and 1. Vectors + * having a pdf below this value will not be clustered. + * @param runSequential + * if true execute sequential algorithm + */ + public static void run(Configuration conf, Path input, Path clustersIn, Path output, + double convergenceDelta, int maxIterations, boolean runClustering, double clusterClassificationThreshold, + boolean runSequential) throws IOException, InterruptedException, ClassNotFoundException { + + // iterate until the clusters converge + String delta = Double.toString(convergenceDelta); + if (log.isInfoEnabled()) { + log.info("Input: {} Clusters In: {} Out: {}", input, clustersIn, output); + log.info("convergence: {} max Iterations: {}", convergenceDelta, maxIterations); + } + Path clustersOut = buildClusters(conf, input, clustersIn, output, maxIterations, delta, runSequential); + if (runClustering) { + log.info("Clustering data"); + clusterData(conf, input, clustersOut, output, clusterClassificationThreshold, runSequential); + } + } + + /** + * Iterate over the input vectors to produce clusters and, if requested, use the results of the final iteration to + * cluster the input vectors. + * + * @param input + * the directory pathname for input points + * @param clustersIn + * the directory pathname for initial & computed clusters + * @param output + * the directory pathname for output points + * @param convergenceDelta + * the convergence delta value + * @param maxIterations + * the maximum number of iterations + * @param runClustering + * true if points are to be clustered after iterations are completed + * @param clusterClassificationThreshold + * A clustering strictness / outlier removal parameter; its value should be between 0 and 1. Vectors + * having a pdf below this value will not be clustered. + * @param runSequential + * if true execute sequential algorithm + */ + public static void run(Path input, Path clustersIn, Path output, double convergenceDelta, + int maxIterations, boolean runClustering, double clusterClassificationThreshold, boolean runSequential) + throws IOException, InterruptedException, ClassNotFoundException { + run(new Configuration(), input, clustersIn, output, convergenceDelta, maxIterations, runClustering, + clusterClassificationThreshold, runSequential); + } + + /** + * Iterate over the input vectors to produce cluster directories for each iteration + * + * + * @param conf + * the Configuration to use + * @param input + * the directory pathname for input points + * @param clustersIn + * the directory pathname for initial & computed clusters + * @param output + * the directory pathname for output points + * @param maxIterations + * the maximum number of iterations + * @param delta + * the convergence delta value + * @param runSequential + * if true execute sequential algorithm + * + * @return the Path of the final clusters directory + */ + public static Path buildClusters(Configuration conf, Path input, Path clustersIn, Path output, + int maxIterations, String delta, boolean runSequential) throws IOException, + InterruptedException, ClassNotFoundException { + + double convergenceDelta = Double.parseDouble(delta); + List<Cluster> clusters = new ArrayList<>(); + KMeansUtil.configureWithClusterInfo(conf, clustersIn, clusters); + + if (clusters.isEmpty()) { + throw new IllegalStateException("No input clusters found in " + clustersIn + ". Check your -c argument."); + } + + Path priorClustersPath = new Path(output, Cluster.INITIAL_CLUSTERS_DIR); + ClusteringPolicy policy = new KMeansClusteringPolicy(convergenceDelta); + ClusterClassifier prior = new ClusterClassifier(clusters, policy); + prior.writeToSeqFiles(priorClustersPath); + + if (runSequential) { + ClusterIterator.iterateSeq(conf, input, priorClustersPath, output, maxIterations); + } else { + ClusterIterator.iterateMR(conf, input, priorClustersPath, output, maxIterations); + } + return output; + } + + /** + * Run the job using supplied arguments + * + * @param input + * the directory pathname for input points + * @param clustersIn + * the directory pathname for input clusters + * @param output + * the directory pathname for output points + * @param clusterClassificationThreshold + * A clustering strictness / outlier removal parameter; its value should be between 0 and 1. Vectors + * having a pdf below this value will not be clustered. + * @param runSequential + * if true execute sequential algorithm + */ + public static void clusterData(Configuration conf, Path input, Path clustersIn, Path output, + double clusterClassificationThreshold, boolean runSequential) throws IOException, InterruptedException, + ClassNotFoundException { + + if (log.isInfoEnabled()) { + log.info("Running Clustering"); + log.info("Input: {} Clusters In: {} Out: {}", input, clustersIn, output); + } + ClusterClassifier.writePolicy(new KMeansClusteringPolicy(), clustersIn); + ClusterClassificationDriver.run(conf, input, output, new Path(output, PathDirectory.CLUSTERED_POINTS_DIRECTORY), + clusterClassificationThreshold, true, runSequential); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/kmeans/KMeansUtil.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/kmeans/KMeansUtil.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/kmeans/KMeansUtil.java new file mode 100644 index 0000000..3365f70 --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/kmeans/KMeansUtil.java @@ -0,0 +1,74 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements.  See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.clustering.kmeans; + +import java.util.Collection; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.Writable; +import org.apache.mahout.clustering.Cluster; +import org.apache.mahout.clustering.canopy.Canopy; +import org.apache.mahout.clustering.iterator.ClusterWritable; +import org.apache.mahout.common.iterator.sequencefile.PathFilters; +import org.apache.mahout.common.iterator.sequencefile.PathType; +import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirValueIterable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +final class KMeansUtil { + + private static final Logger log = LoggerFactory.getLogger(KMeansUtil.class); + + private KMeansUtil() {} + + /** + * Create a list of Klusters from whatever Cluster type is passed in as the prior + * + * @param conf + * the Configuration + * @param clusterPath + * the path to the prior Clusters + * @param clusters + * a Collection<Cluster> to put values into + */ + public static void configureWithClusterInfo(Configuration conf, Path clusterPath, Collection<Cluster> clusters) { + for (Writable value : new SequenceFileDirValueIterable<>(clusterPath, PathType.LIST, + PathFilters.partFilter(), conf)) { + Class<? extends Writable> valueClass = value.getClass(); + if (valueClass.equals(ClusterWritable.class)) { + ClusterWritable clusterWritable = (ClusterWritable) value; + value = clusterWritable.getValue(); + valueClass = value.getClass(); + } + log.debug("Read 1 Cluster from {}", clusterPath); + + if (valueClass.equals(Kluster.class)) { + // get the cluster info + clusters.add((Kluster) value); + } else if (valueClass.equals(Canopy.class)) { + // get the cluster info + Canopy canopy = (Canopy) value; + clusters.add(new Kluster(canopy.getCenter(), canopy.getId(), canopy.getMeasure())); + } else { + throw new IllegalStateException("Bad value class: " + valueClass); + } + } + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/kmeans/Kluster.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/kmeans/Kluster.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/kmeans/Kluster.java new file mode 100644 index 0000000..15daec5 --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/kmeans/Kluster.java @@ -0,0 +1,117 @@ +/* Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements.  See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership.  The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License.  You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.mahout.clustering.kmeans; + +import java.io.DataInput; +import java.io.DataOutput; +import java.io.IOException; + +import org.apache.mahout.clustering.iterator.DistanceMeasureCluster; +import org.apache.mahout.common.distance.DistanceMeasure; +import org.apache.mahout.math.Vector; + +public class Kluster extends DistanceMeasureCluster { + + /** Has the centroid converged with the center? */ + private boolean converged; + + /** For (de)serialization as a Writable */ + public Kluster() { + } + + /** + * Construct a new cluster with the given point as its center + * + * @param center + * the Vector center + * @param clusterId + * the int cluster id + * @param measure + * a DistanceMeasure + */ + public Kluster(Vector center, int clusterId, DistanceMeasure measure) { + super(center, clusterId, measure); + } + + /** + * Format the cluster for output + * + * @param cluster + * the Cluster + * @return the String representation of the Cluster + */ + public static String formatCluster(Kluster cluster) { + return cluster.getIdentifier() + ": " + cluster.computeCentroid().asFormatString(); + } + + public String asFormatString() { + return formatCluster(this); + } + + @Override + public void write(DataOutput out) throws IOException { + super.write(out); + out.writeBoolean(converged); + } + + @Override + public void readFields(DataInput in) throws IOException { + super.readFields(in); + this.converged = in.readBoolean(); + } + + @Override + public String toString() { + return asFormatString(null); + } + + @Override + public String getIdentifier() { + return (converged ? "VL-" : "CL-") + getId(); + } + + /** + * Return whether the cluster has converged, by comparing its center and centroid. + * + * @param measure + * The distance measure to use for cluster-point comparisons. + * @param convergenceDelta + * the convergence delta to use for stopping. + * @return true if the cluster has converged + */ + public boolean computeConvergence(DistanceMeasure measure, double convergenceDelta) { + Vector centroid = computeCentroid(); + converged = measure.distance(centroid.getLengthSquared(), centroid, getCenter()) <= convergenceDelta; + return converged; + } + + @Override + public boolean isConverged() { + return converged; + } + + protected void setConverged(boolean converged) { + this.converged = converged; + } + + public boolean calculateConvergence(double convergenceDelta) { + Vector centroid = computeCentroid(); + converged = getMeasure().distance(centroid.getLengthSquared(), centroid, getCenter()) <= convergenceDelta; + return converged; + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/kmeans/RandomSeedGenerator.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/kmeans/RandomSeedGenerator.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/kmeans/RandomSeedGenerator.java new file mode 100644 index 0000000..fbbabc5 --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/kmeans/RandomSeedGenerator.java @@ -0,0 +1,136 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements.  See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.clustering.kmeans; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Random; + +import com.google.common.base.Preconditions; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FileStatus; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.SequenceFile; +import org.apache.hadoop.io.Text; +import org.apache.hadoop.io.Writable; +import org.apache.mahout.clustering.iterator.ClusterWritable; +import org.apache.mahout.common.HadoopUtil; +import org.apache.mahout.common.Pair; +import org.apache.mahout.common.RandomUtils; +import org.apache.mahout.common.distance.DistanceMeasure; +import org.apache.mahout.common.iterator.sequencefile.PathFilters; +import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable; +import org.apache.mahout.math.VectorWritable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Given an Input Path containing a {@link org.apache.hadoop.io.SequenceFile}, randomly select k vectors and + * write them to the output file as {@link org.apache.mahout.clustering.kmeans.Kluster}s representing the + * initial centroids to use. + * + * This implementation uses reservoir sampling as described in http://en.wikipedia.org/wiki/Reservoir_sampling + */ +public final class RandomSeedGenerator { + + private static final Logger log = LoggerFactory.getLogger(RandomSeedGenerator.class); + + public static final String K = "k"; + + private RandomSeedGenerator() {} + + public static Path buildRandom(Configuration conf, Path input, Path output, int k, DistanceMeasure measure) + throws IOException { + return buildRandom(conf, input, output, k, measure, null); + } + + public static Path buildRandom(Configuration conf, + Path input, + Path output, + int k, + DistanceMeasure measure, + Long seed) throws IOException { + + Preconditions.checkArgument(k > 0, "Must be: k > 0, but k = " + k); + // delete the output directory + FileSystem fs = FileSystem.get(output.toUri(), conf); + HadoopUtil.delete(conf, output); + Path outFile = new Path(output, "part-randomSeed"); + boolean newFile = fs.createNewFile(outFile); + if (newFile) { + Path inputPathPattern; + + if (fs.getFileStatus(input).isDir()) { + inputPathPattern = new Path(input, "*"); + } else { + inputPathPattern = input; + } + + FileStatus[] inputFiles = fs.globStatus(inputPathPattern, PathFilters.logsCRCFilter()); + + Random random = (seed != null) ? RandomUtils.getRandom(seed) : RandomUtils.getRandom(); + + List<Text> chosenTexts = new ArrayList<>(k); + List<ClusterWritable> chosenClusters = new ArrayList<>(k); + int nextClusterId = 0; + + int index = 0; + for (FileStatus fileStatus : inputFiles) { + if (!fileStatus.isDir()) { + for (Pair<Writable, VectorWritable> record + : new SequenceFileIterable<Writable, VectorWritable>(fileStatus.getPath(), true, conf)) { + Writable key = record.getFirst(); + VectorWritable value = record.getSecond(); + Kluster newCluster = new Kluster(value.get(), nextClusterId++, measure); + newCluster.observe(value.get(), 1); + Text newText = new Text(key.toString()); + int currentSize = chosenTexts.size(); + if (currentSize < k) { + chosenTexts.add(newText); + ClusterWritable clusterWritable = new ClusterWritable(); + clusterWritable.setValue(newCluster); + chosenClusters.add(clusterWritable); + } else { + // replace a previously chosen seed with probability k/(index + 1), keeping the sample uniform + int j = random.nextInt(index + 1); + if (j < k) { + chosenTexts.set(j, newText); + ClusterWritable clusterWritable = new ClusterWritable(); + clusterWritable.setValue(newCluster); + chosenClusters.set(j, clusterWritable); + } + } + index++; + } + } + } + + try (SequenceFile.Writer writer = + SequenceFile.createWriter(fs, conf, outFile, Text.class, ClusterWritable.class)){ + for (int i = 0; i < chosenTexts.size(); i++) { + writer.append(chosenTexts.get(i), chosenClusters.get(i)); + } + log.info("Wrote {} Klusters to {}", k, outFile); + } + } + + return outFile; + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/kmeans/package-info.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/kmeans/package-info.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/kmeans/package-info.java new file mode 100644 index 0000000..d6921b6 --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/kmeans/package-info.java @@ -0,0 +1,5 @@ +/** + * This package provides an implementation of the <a href="http://en.wikipedia.org/wiki/Kmeans">k-means</a> clustering + * algorithm. + */ +package org.apache.mahout.clustering.kmeans; http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/CVB0DocInferenceMapper.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/CVB0DocInferenceMapper.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/CVB0DocInferenceMapper.java new file mode 100644 index 0000000..46fcc7f --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/CVB0DocInferenceMapper.java @@ -0,0 +1,51 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements.  See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.mahout.clustering.lda.cvb; + +import org.apache.hadoop.io.IntWritable; +import org.apache.mahout.math.DenseVector; +import org.apache.mahout.math.Matrix; +import org.apache.mahout.math.SparseRowMatrix; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.VectorWritable; + +import java.io.IOException; + +public class CVB0DocInferenceMapper extends CachingCVB0Mapper { + + private final VectorWritable topics = new VectorWritable(); + + @Override + public void map(IntWritable docId, VectorWritable doc, Context context) + throws IOException, InterruptedException { + int numTopics = getNumTopics(); + Vector docTopics = new DenseVector(numTopics).assign(1.0 / numTopics); + Matrix docModel = new SparseRowMatrix(numTopics, doc.get().size()); + int maxIters = getMaxIters(); + ModelTrainer modelTrainer = getModelTrainer(); + for (int i = 0; i < maxIters; i++) { + modelTrainer.getReadModel().trainDocTopicModel(doc.get(), docTopics, docModel); + } + topics.set(docTopics); + context.write(docId, topics); + } + + @Override + protected void cleanup(Context context) { + getModelTrainer().stop(); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/CVB0Driver.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/CVB0Driver.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/CVB0Driver.java new file mode 100644 index 0000000..31c0d60 --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/CVB0Driver.java @@ -0,0 +1,536 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/CVB0Driver.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/CVB0Driver.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/CVB0Driver.java new file mode 100644 index 0000000..31c0d60 --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/CVB0Driver.java @@ -0,0 +1,536 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.mahout.clustering.lda.cvb; + +import java.io.IOException; +import java.net.URI; +import java.util.ArrayList; +import java.util.List; + +import com.google.common.base.Joiner; +import com.google.common.base.Preconditions; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.filecache.DistributedCache; +import org.apache.hadoop.fs.FileStatus; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.DoubleWritable; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.SequenceFile; +import org.apache.hadoop.io.Text; +import org.apache.hadoop.mapreduce.Job; +import org.apache.hadoop.mapreduce.Reducer; +import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat; +import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat; +import org.apache.hadoop.util.ToolRunner; +import org.apache.mahout.common.AbstractJob; +import org.apache.mahout.common.HadoopUtil; +import org.apache.mahout.common.Pair; +import org.apache.mahout.common.commandline.DefaultOptionCreator; +import org.apache.mahout.common.iterator.sequencefile.PathFilters; +import org.apache.mahout.common.iterator.sequencefile.PathType; +import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable; +import org.apache.mahout.common.mapreduce.VectorSumReducer; +import org.apache.mahout.math.VectorWritable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * See {@link CachingCVB0Mapper} for more details on scalability and room for improvement. + * To try out this LDA implementation without using Hadoop, check out + * {@link InMemoryCollapsedVariationalBayes0}. If you want to do training directly in Java code + * with your own main(), then look to {@link ModelTrainer} and {@link TopicModel}. + * + * Usage: {@code ./bin/mahout cvb <i>options</i>} + * <p> + * Valid options include: + * <dl> + * <dt>{@code --input path}</dt> + * <dd>Input path for {@code SequenceFile<IntWritable, VectorWritable>} document vectors. See + * {@link org.apache.mahout.vectorizer.SparseVectorsFromSequenceFiles} + * for details on how to generate this input format.</dd> + * <dt>{@code --dictionary path}</dt> + * <dd>Path to dictionary file(s) generated during construction of input document vectors (glob + * expression supported). If set, this data is scanned to determine an appropriate value for option + * {@code --num_terms}.</dd> + * <dt>{@code --output path}</dt> + * <dd>Output path for topic-term distributions.</dd> + * <dt>{@code --doc_topic_output path}</dt> + * <dd>Output path for doc-topic distributions.</dd> + * <dt>{@code --num_topics k}</dt> + * <dd>Number of latent topics.</dd> + * <dt>{@code --num_terms nt}</dt> + * <dd>Number of unique features defined by input document vectors. If option {@code --dictionary} + * is defined and this option is unspecified, term count is calculated from dictionary.</dd> + * <dt>{@code --topic_model_temp_dir path}</dt> + * <dd>Path in which to store model state after each iteration.</dd> + * <dt>{@code --maxIter i}</dt> + * <dd>Maximum number of iterations to perform. If this value is less than or equal to the number of + * iteration states found beneath the path specified by option {@code --topic_model_temp_dir}, no + * further iterations are performed.
Instead, output topic-term and doc-topic distributions are + * generated using data from the specified iteration.</dd> + * <dt>{@code --max_doc_topic_iters i}</dt> + * <dd>Maximum number of iterations per doc for p(topic|doc) learning. Defaults to {@code 10}.</dd> + * <dt>{@code --doc_topic_smoothing a}</dt> + * <dd>Smoothing for doc-topic distribution. Defaults to {@code 0.0001}.</dd> + * <dt>{@code --term_topic_smoothing e}</dt> + * <dd>Smoothing for topic-term distribution. Defaults to {@code 0.0001}.</dd> + * <dt>{@code --random_seed seed}</dt> + * <dd>Integer seed for random number generation.</dd> + * <dt>{@code --test_set_fraction p}</dt> + * <dd>Fraction of data to hold out for testing. Defaults to {@code 0.0}.</dd> + * <dt>{@code --iteration_block_size block}</dt> + * <dd>Number of iterations between perplexity checks. Defaults to {@code 10}. This option is + * ignored unless option {@code --test_set_fraction} is greater than zero.</dd> + * </dl> + */ +public class CVB0Driver extends AbstractJob { + private static final Logger log = LoggerFactory.getLogger(CVB0Driver.class); + + public static final String NUM_TOPICS = "num_topics"; + public static final String NUM_TERMS = "num_terms"; + public static final String DOC_TOPIC_SMOOTHING = "doc_topic_smoothing"; + public static final String TERM_TOPIC_SMOOTHING = "term_topic_smoothing"; + public static final String DICTIONARY = "dictionary"; + public static final String DOC_TOPIC_OUTPUT = "doc_topic_output"; + public static final String MODEL_TEMP_DIR = "topic_model_temp_dir"; + public static final String ITERATION_BLOCK_SIZE = "iteration_block_size"; + public static final String RANDOM_SEED = "random_seed"; + public static final String TEST_SET_FRACTION = "test_set_fraction"; + public static final String NUM_TRAIN_THREADS = "num_train_threads"; + public static final String NUM_UPDATE_THREADS = "num_update_threads"; + public static final String MAX_ITERATIONS_PER_DOC = "max_doc_topic_iters"; + public static final String MODEL_WEIGHT = "prev_iter_mult"; + public static final String NUM_REDUCE_TASKS = "num_reduce_tasks"; + public static final String BACKFILL_PERPLEXITY = "backfill_perplexity"; + private static final String MODEL_PATHS = "mahout.lda.cvb.modelPath"; + + private static final double DEFAULT_CONVERGENCE_DELTA = 0; + private static final double DEFAULT_DOC_TOPIC_SMOOTHING = 0.0001; + private static final double DEFAULT_TERM_TOPIC_SMOOTHING = 0.0001; + private static final int DEFAULT_ITERATION_BLOCK_SIZE = 10; + private static final double DEFAULT_TEST_SET_FRACTION = 0; + private static final int DEFAULT_NUM_TRAIN_THREADS = 4; + private static final int DEFAULT_NUM_UPDATE_THREADS = 1; + private static final int DEFAULT_MAX_ITERATIONS_PER_DOC = 10; + private static final int DEFAULT_NUM_REDUCE_TASKS = 10; + + @Override + public int run(String[] args) throws Exception { + addInputOption(); + addOutputOption(); + addOption(DefaultOptionCreator.maxIterationsOption().create()); + addOption(DefaultOptionCreator.CONVERGENCE_DELTA_OPTION, "cd", "The convergence delta value", + String.valueOf(DEFAULT_CONVERGENCE_DELTA)); + addOption(DefaultOptionCreator.overwriteOption().create()); + + addOption(NUM_TOPICS, "k", "Number of topics to learn", true); + addOption(NUM_TERMS, "nt", "Vocabulary size", false); + addOption(DOC_TOPIC_SMOOTHING, "a", "Smoothing for document/topic distribution", + String.valueOf(DEFAULT_DOC_TOPIC_SMOOTHING)); + addOption(TERM_TOPIC_SMOOTHING, "e", "Smoothing for topic/term distribution", 
String.valueOf(DEFAULT_TERM_TOPIC_SMOOTHING)); + addOption(DICTIONARY, "dict", "Path to term-dictionary file(s) (glob expression supported)", false); + addOption(DOC_TOPIC_OUTPUT, "dt", "Output path for the training doc/topic distribution", false); + addOption(MODEL_TEMP_DIR, "mt", "Path for intermediate model state (useful for restarting)", false); + addOption(ITERATION_BLOCK_SIZE, "block", "Number of iterations per perplexity check", + String.valueOf(DEFAULT_ITERATION_BLOCK_SIZE)); + addOption(RANDOM_SEED, "seed", "Random seed", false); + addOption(TEST_SET_FRACTION, "tf", "Fraction of data to hold out for testing", + String.valueOf(DEFAULT_TEST_SET_FRACTION)); + addOption(NUM_TRAIN_THREADS, "ntt", "Number of threads per mapper to train with", + String.valueOf(DEFAULT_NUM_TRAIN_THREADS)); + addOption(NUM_UPDATE_THREADS, "nut", "Number of threads per mapper to update the model with", + String.valueOf(DEFAULT_NUM_UPDATE_THREADS)); + addOption(MAX_ITERATIONS_PER_DOC, "mipd", "Max number of iterations per doc for p(topic|doc) learning", + String.valueOf(DEFAULT_MAX_ITERATIONS_PER_DOC)); + addOption(NUM_REDUCE_TASKS, null, "Number of reducers to use during model estimation", + String.valueOf(DEFAULT_NUM_REDUCE_TASKS)); + addOption(buildOption(BACKFILL_PERPLEXITY, null, "Enable backfilling of missing perplexity values", false, false, + null)); + + if (parseArguments(args) == null) { + return -1; + } + + int numTopics = Integer.parseInt(getOption(NUM_TOPICS)); + Path inputPath = getInputPath(); + Path topicModelOutputPath = getOutputPath(); + int maxIterations = Integer.parseInt(getOption(DefaultOptionCreator.MAX_ITERATIONS_OPTION)); + int iterationBlockSize = Integer.parseInt(getOption(ITERATION_BLOCK_SIZE)); + double convergenceDelta = Double.parseDouble(getOption(DefaultOptionCreator.CONVERGENCE_DELTA_OPTION)); + double alpha = Double.parseDouble(getOption(DOC_TOPIC_SMOOTHING)); + double eta = Double.parseDouble(getOption(TERM_TOPIC_SMOOTHING)); + int numTrainThreads = Integer.parseInt(getOption(NUM_TRAIN_THREADS)); + int numUpdateThreads = Integer.parseInt(getOption(NUM_UPDATE_THREADS)); + int maxItersPerDoc = Integer.parseInt(getOption(MAX_ITERATIONS_PER_DOC)); + Path dictionaryPath = hasOption(DICTIONARY) ? new Path(getOption(DICTIONARY)) : null; + int numTerms = hasOption(NUM_TERMS) + ? Integer.parseInt(getOption(NUM_TERMS)) + : getNumTerms(getConf(), dictionaryPath); + Path docTopicOutputPath = hasOption(DOC_TOPIC_OUTPUT) ? new Path(getOption(DOC_TOPIC_OUTPUT)) : null; + Path modelTempPath = hasOption(MODEL_TEMP_DIR) + ? new Path(getOption(MODEL_TEMP_DIR)) + : getTempPath("topicModelState"); + long seed = hasOption(RANDOM_SEED) + ? Long.parseLong(getOption(RANDOM_SEED)) + : System.nanoTime() % 10000; + float testFraction = hasOption(TEST_SET_FRACTION) + ? 
Float.parseFloat(getOption(TEST_SET_FRACTION)) + : 0.0f; + int numReduceTasks = Integer.parseInt(getOption(NUM_REDUCE_TASKS)); + boolean backfillPerplexity = hasOption(BACKFILL_PERPLEXITY); + + return run(getConf(), inputPath, topicModelOutputPath, numTopics, numTerms, alpha, eta, + maxIterations, iterationBlockSize, convergenceDelta, dictionaryPath, docTopicOutputPath, + modelTempPath, seed, testFraction, numTrainThreads, numUpdateThreads, maxItersPerDoc, + numReduceTasks, backfillPerplexity); + } + + private static int getNumTerms(Configuration conf, Path dictionaryPath) throws IOException { + FileSystem fs = dictionaryPath.getFileSystem(conf); + Text key = new Text(); + IntWritable value = new IntWritable(); + int maxTermId = -1; + for (FileStatus stat : fs.globStatus(dictionaryPath)) { + // close each reader once it has been scanned; the dictionary may span many part files + try (SequenceFile.Reader reader = new SequenceFile.Reader(fs, stat.getPath(), conf)) { + while (reader.next(key, value)) { + maxTermId = Math.max(maxTermId, value.get()); + } + } + } + return maxTermId + 1; + } + + public int run(Configuration conf, + Path inputPath, + Path topicModelOutputPath, + int numTopics, + int numTerms, + double alpha, + double eta, + int maxIterations, + int iterationBlockSize, + double convergenceDelta, + Path dictionaryPath, + Path docTopicOutputPath, + Path topicModelStateTempPath, + long randomSeed, + float testFraction, + int numTrainThreads, + int numUpdateThreads, + int maxItersPerDoc, + int numReduceTasks, + boolean backfillPerplexity) + throws ClassNotFoundException, IOException, InterruptedException { + + setConf(conf); + + // verify arguments + Preconditions.checkArgument(testFraction >= 0.0 && testFraction <= 1.0, + "Expected 'testFraction' value in range [0, 1] but found value '%s'", testFraction); + Preconditions.checkArgument(!backfillPerplexity || testFraction > 0.0, + "Expected 'testFraction' value in range (0, 1] but found value '%s'", testFraction); + + String infoString = "Will run Collapsed Variational Bayes (0th-derivative approximation) " + + "learning for LDA on {} (numTerms: {}), finding {} topics, with document/topic prior {}, " + + "topic/term prior {}. Maximum iterations to run will be {}, unless the change in " + + "perplexity is less than {}. Topic model output (p(term|topic) for each topic) will be " + + "stored in {}. Random initialization seed is {}, holding out {} of the data for perplexity " + + "check\n"; + log.info(infoString, inputPath, numTerms, numTopics, alpha, eta, maxIterations, + convergenceDelta, topicModelOutputPath, randomSeed, testFraction); + infoString = dictionaryPath == null + ? "" : "Dictionary to be used is located at " + dictionaryPath.toString() + '\n'; + infoString += docTopicOutputPath == null + ? 
"" : "p(topic|docId) will be stored " + docTopicOutputPath.toString() + '\n'; + log.info(infoString); + + FileSystem fs = FileSystem.get(topicModelStateTempPath.toUri(), conf); + int iterationNumber = getCurrentIterationNumber(conf, topicModelStateTempPath, maxIterations); + log.info("Current iteration number: {}", iterationNumber); + + conf.set(NUM_TOPICS, String.valueOf(numTopics)); + conf.set(NUM_TERMS, String.valueOf(numTerms)); + conf.set(DOC_TOPIC_SMOOTHING, String.valueOf(alpha)); + conf.set(TERM_TOPIC_SMOOTHING, String.valueOf(eta)); + conf.set(RANDOM_SEED, String.valueOf(randomSeed)); + conf.set(NUM_TRAIN_THREADS, String.valueOf(numTrainThreads)); + conf.set(NUM_UPDATE_THREADS, String.valueOf(numUpdateThreads)); + conf.set(MAX_ITERATIONS_PER_DOC, String.valueOf(maxItersPerDoc)); + conf.set(MODEL_WEIGHT, "1"); // TODO + conf.set(TEST_SET_FRACTION, String.valueOf(testFraction)); + + List<Double> perplexities = new ArrayList<>(); + for (int i = 1; i <= iterationNumber; i++) { + // form path to model + Path modelPath = modelPath(topicModelStateTempPath, i); + + // read perplexity + double perplexity = readPerplexity(conf, topicModelStateTempPath, i); + if (Double.isNaN(perplexity)) { + if (!(backfillPerplexity && i % iterationBlockSize == 0)) { + continue; + } + log.info("Backfilling perplexity at iteration {}", i); + if (!fs.exists(modelPath)) { + log.error("Model path '{}' does not exist; Skipping iteration {} perplexity calculation", + modelPath.toString(), i); + continue; + } + perplexity = calculatePerplexity(conf, inputPath, modelPath, i); + } + + // register and log perplexity + perplexities.add(perplexity); + log.info("Perplexity at iteration {} = {}", i, perplexity); + } + + long startTime = System.currentTimeMillis(); + while (iterationNumber < maxIterations) { + // test convergence + if (convergenceDelta > 0.0) { + double delta = rateOfChange(perplexities); + if (delta < convergenceDelta) { + log.info("Convergence achieved at iteration {} with perplexity {} and delta {}", + iterationNumber, perplexities.get(perplexities.size() - 1), delta); + break; + } + } + + // update model + iterationNumber++; + log.info("About to run iteration {} of {}", iterationNumber, maxIterations); + Path modelInputPath = modelPath(topicModelStateTempPath, iterationNumber - 1); + Path modelOutputPath = modelPath(topicModelStateTempPath, iterationNumber); + runIteration(conf, inputPath, modelInputPath, modelOutputPath, iterationNumber, + maxIterations, numReduceTasks); + + // calculate perplexity + if (testFraction > 0 && iterationNumber % iterationBlockSize == 0) { + perplexities.add(calculatePerplexity(conf, inputPath, modelOutputPath, iterationNumber)); + log.info("Current perplexity = {}", perplexities.get(perplexities.size() - 1)); + log.info("(p_{} - p_{}) / p_0 = {}; target = {}", iterationNumber, iterationNumber - iterationBlockSize, + rateOfChange(perplexities), convergenceDelta); + } + } + log.info("Completed {} iterations in {} seconds", iterationNumber, + (System.currentTimeMillis() - startTime) / 1000); + log.info("Perplexities: ({})", Joiner.on(", ").join(perplexities)); + + // write final topic-term and doc-topic distributions + Path finalIterationData = modelPath(topicModelStateTempPath, iterationNumber); + Job topicModelOutputJob = topicModelOutputPath != null + ? writeTopicModel(conf, finalIterationData, topicModelOutputPath) + : null; + Job docInferenceJob = docTopicOutputPath != null + ? 
writeDocTopicInference(conf, inputPath, finalIterationData, docTopicOutputPath) + : null; + if (topicModelOutputJob != null && !topicModelOutputJob.waitForCompletion(true)) { + return -1; + } + if (docInferenceJob != null && !docInferenceJob.waitForCompletion(true)) { + return -1; + } + return 0; + } + + private static double rateOfChange(List<Double> perplexities) { + int sz = perplexities.size(); + if (sz < 2) { + return Double.MAX_VALUE; + } + return Math.abs(perplexities.get(sz - 1) - perplexities.get(sz - 2)) / perplexities.get(0); + } + + private double calculatePerplexity(Configuration conf, Path corpusPath, Path modelPath, int iteration) + throws IOException, ClassNotFoundException, InterruptedException { + String jobName = "Calculating perplexity for " + modelPath; + log.info("About to run: {}", jobName); + + Path outputPath = perplexityPath(modelPath.getParent(), iteration); + Job job = prepareJob(corpusPath, outputPath, CachingCVB0PerplexityMapper.class, DoubleWritable.class, + DoubleWritable.class, DualDoubleSumReducer.class, DoubleWritable.class, DoubleWritable.class); + + job.setJobName(jobName); + job.setCombinerClass(DualDoubleSumReducer.class); + job.setNumReduceTasks(1); + setModelPaths(job, modelPath); + HadoopUtil.delete(conf, outputPath); + if (!job.waitForCompletion(true)) { + throw new InterruptedException("Failed to calculate perplexity for: " + modelPath); + } + return readPerplexity(conf, modelPath.getParent(), iteration); + } + + /** + * Sums keys and values independently. + */ + public static class DualDoubleSumReducer extends + Reducer<DoubleWritable, DoubleWritable, DoubleWritable, DoubleWritable> { + private final DoubleWritable outKey = new DoubleWritable(); + private final DoubleWritable outValue = new DoubleWritable(); + + @Override + public void run(Context context) throws IOException, + InterruptedException { + double keySum = 0.0; + double valueSum = 0.0; + while (context.nextKey()) { + keySum += context.getCurrentKey().get(); + for (DoubleWritable value : context.getValues()) { + valueSum += value.get(); + } + } + outKey.set(keySum); + outValue.set(valueSum); + context.write(outKey, outValue); + } + }
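Note how convergence is measured: rateOfChange() divides the gap between the last two recorded perplexities by the very first one, so --cd is a threshold relative to the initial perplexity rather than an absolute difference. A tiny worked example with invented numbers:

final class ConvergenceCheckExample {
  public static void main(String[] args) {
    // perplexities recorded at successive perplexity checks (invented values)
    double[] p = {1000.0, 920.0, 915.0};
    int sz = p.length;
    double delta = Math.abs(p[sz - 1] - p[sz - 2]) / p[0]; // |915 - 920| / 1000 = 0.005
    System.out.println(delta);
    // with --cd 0.01 the training loop would stop here, since 0.005 < 0.01
  }
}

The DualDoubleSumReducer above feeds this: it collapses the per-document (model weight, perplexity) pairs into a single pair of sums, which readPerplexity() below turns into one number.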
+ + /** + * Reads the per-document perplexity data written for a given iteration and reduces it to a single number. + * + * @param conf the Configuration + * @param topicModelStateTemp the path beneath which model state and perplexity data are stored per iteration + * @param iteration the iteration whose perplexity should be read + * @return the perplexity of the documents sampled during the perplexity computation, computed as their + * summed perplexity contributions divided by their summed model weight, or {@code Double.NaN} if no + * perplexity data exists for the given iteration + * @throws IOException + */ + public static double readPerplexity(Configuration conf, Path topicModelStateTemp, int iteration) + throws IOException { + Path perplexityPath = perplexityPath(topicModelStateTemp, iteration); + FileSystem fs = FileSystem.get(perplexityPath.toUri(), conf); + if (!fs.exists(perplexityPath)) { + log.warn("Perplexity path {} does not exist, returning NaN", perplexityPath); + return Double.NaN; + } + double perplexity = 0; + double modelWeight = 0; + long n = 0; + for (Pair<DoubleWritable, DoubleWritable> pair : new SequenceFileDirIterable<DoubleWritable, DoubleWritable>( + perplexityPath, PathType.LIST, PathFilters.partFilter(), null, true, conf)) { + modelWeight += pair.getFirst().get(); + perplexity += pair.getSecond().get(); + n++; + } + log.info("Read {} entries with total perplexity {} and model weight {}", n, + perplexity, modelWeight); + return perplexity / modelWeight; + } + + private Job writeTopicModel(Configuration conf, Path modelInput, Path output) + throws IOException, InterruptedException, ClassNotFoundException { + String jobName = String.format("Writing final topic/term distributions from %s to %s", modelInput, output); + log.info("About to run: {}", jobName); + + Job job = prepareJob(modelInput, output, SequenceFileInputFormat.class, CVB0TopicTermVectorNormalizerMapper.class, + IntWritable.class, VectorWritable.class, SequenceFileOutputFormat.class, jobName); + job.submit(); + return job; + } + + private Job writeDocTopicInference(Configuration conf, Path corpus, Path modelInput, Path output) + throws IOException, ClassNotFoundException, InterruptedException { + String jobName = String.format("Writing final document/topic inference from %s to %s", corpus, output); + log.info("About to run: {}", jobName); + + Job job = prepareJob(corpus, output, SequenceFileInputFormat.class, CVB0DocInferenceMapper.class, + IntWritable.class, VectorWritable.class, SequenceFileOutputFormat.class, jobName); + + FileSystem fs = FileSystem.get(corpus.toUri(), conf); + if (modelInput != null && fs.exists(modelInput)) { + FileStatus[] statuses = fs.listStatus(modelInput, PathFilters.partFilter()); + URI[] modelUris = new URI[statuses.length]; + for (int i = 0; i < statuses.length; i++) { + modelUris[i] = statuses[i].getPath().toUri(); + } + // the cache files must be registered on the job's own Configuration; 'conf' was copied when the job was created + DistributedCache.setCacheFiles(modelUris, job.getConfiguration()); + setModelPaths(job, modelInput); + } + job.submit(); + return job; + } + + public static Path modelPath(Path topicModelStateTempPath, int iterationNumber) { + return new Path(topicModelStateTempPath, "model-" + iterationNumber); + } + + public static Path perplexityPath(Path topicModelStateTempPath, int iterationNumber) { + return new Path(topicModelStateTempPath, "perplexity-" + iterationNumber); + } + + private static int getCurrentIterationNumber(Configuration config, Path modelTempDir, int maxIterations) + throws IOException { + FileSystem fs = FileSystem.get(modelTempDir.toUri(), config); + int iterationNumber = 1; + Path iterationPath = modelPath(modelTempDir, iterationNumber); + while (fs.exists(iterationPath) && iterationNumber <= maxIterations) { + log.info("Found previous state: {}", iterationPath); + iterationNumber++; + iterationPath = modelPath(modelTempDir, iterationNumber); + } + return iterationNumber - 1; + } + + public void runIteration(Configuration conf, Path corpusInput, Path modelInput, Path modelOutput, + int iterationNumber, int maxIterations, int numReduceTasks) + throws IOException, ClassNotFoundException, InterruptedException { + String jobName = 
String.format("Iteration %d of %d, input path: %s", + iterationNumber, maxIterations, modelInput); + log.info("About to run: {}", jobName); + Job job = prepareJob(corpusInput, modelOutput, CachingCVB0Mapper.class, IntWritable.class, VectorWritable.class, + VectorSumReducer.class, IntWritable.class, VectorWritable.class); + job.setCombinerClass(VectorSumReducer.class); + job.setNumReduceTasks(numReduceTasks); + job.setJobName(jobName); + setModelPaths(job, modelInput); + HadoopUtil.delete(conf, modelOutput); + if (!job.waitForCompletion(true)) { + throw new InterruptedException(String.format("Failed to complete iteration %d stage 1", + iterationNumber)); + } + } + + private static void setModelPaths(Job job, Path modelPath) throws IOException { + Configuration conf = job.getConfiguration(); + if (modelPath == null || !FileSystem.get(modelPath.toUri(), conf).exists(modelPath)) { + return; + } + FileStatus[] statuses = FileSystem.get(modelPath.toUri(), conf).listStatus(modelPath, PathFilters.partFilter()); + Preconditions.checkState(statuses.length > 0, "No part files found in model path '%s'", modelPath.toString()); + String[] modelPaths = new String[statuses.length]; + for (int i = 0; i < statuses.length; i++) { + modelPaths[i] = statuses[i].getPath().toUri().toString(); + } + conf.setStrings(MODEL_PATHS, modelPaths); + } + + public static Path[] getModelPaths(Configuration conf) { + String[] modelPathNames = conf.getStrings(MODEL_PATHS); + if (modelPathNames == null || modelPathNames.length == 0) { + return null; + } + Path[] modelPaths = new Path[modelPathNames.length]; + for (int i = 0; i < modelPathNames.length; i++) { + modelPaths[i] = new Path(modelPathNames[i]); + } + return modelPaths; + } + + public static void main(String[] args) throws Exception { + ToolRunner.run(new Configuration(), new CVB0Driver(), args); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/CVB0TopicTermVectorNormalizerMapper.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/CVB0TopicTermVectorNormalizerMapper.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/CVB0TopicTermVectorNormalizerMapper.java new file mode 100644 index 0000000..1253942 --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/CVB0TopicTermVectorNormalizerMapper.java @@ -0,0 +1,38 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.mahout.clustering.lda.cvb; + +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.mapreduce.Mapper; +import org.apache.mahout.math.VectorWritable; +import org.apache.mahout.math.function.Functions; + +import java.io.IOException; + +/** + * Performs L1 normalization of input vectors. + */ +public class CVB0TopicTermVectorNormalizerMapper extends + Mapper<IntWritable, VectorWritable, IntWritable, VectorWritable> { + + @Override + protected void map(IntWritable key, VectorWritable value, Context context) throws IOException, + InterruptedException { + value.get().assign(Functions.div(value.get().norm(1.0))); + context.write(key, value); + } +}
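The assign(Functions.div(norm)) idiom above rescales each topic's term-count vector into the distribution p(term|topic) by dividing every element by the vector's L1 norm. A self-contained sketch of the same operation, assuming only mahout-math on the classpath:

import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.function.Functions;

final class L1NormalizeExample {
  public static void main(String[] args) {
    Vector counts = new DenseVector(new double[] {2.0, 3.0, 5.0}); // invented topic-term counts
    counts.assign(Functions.div(counts.norm(1.0))); // divide each element by 2 + 3 + 5 = 10
    System.out.println(counts); // elements become 0.2, 0.3, 0.5 and now sum to 1
  }
}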

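Taken together, CVB0Driver extends AbstractJob and is run through ToolRunner in its main(), so besides the ./bin/mahout cvb command line described in its javadoc it can be launched programmatically. A hedged sketch of such an invocation; every path and parameter value below is illustrative only:

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.util.ToolRunner;
import org.apache.mahout.clustering.lda.cvb.CVB0Driver;

final class RunCvb0Example {
  public static void main(String[] args) throws Exception {
    int exitCode = ToolRunner.run(new Configuration(), new CVB0Driver(), new String[] {
        "--input", "/tmp/lda/matrix",                 // SequenceFile<IntWritable, VectorWritable> corpus
        "--output", "/tmp/lda/topic-term",            // p(term|topic), one vector per topic
        "--doc_topic_output", "/tmp/lda/doc-topics",  // p(topic|doc), one vector per document
        "--dictionary", "/tmp/lda/dictionary.file-*", // lets the driver derive --num_terms
        "--num_topics", "20",
        "--maxIter", "30"
    });
    System.exit(exitCode);
  }
}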