http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/CachingCVB0Mapper.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/CachingCVB0Mapper.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/CachingCVB0Mapper.java new file mode 100644 index 0000000..96f36d4 --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/CachingCVB0Mapper.java @@ -0,0 +1,133 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.mahout.clustering.lda.cvb; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.mapreduce.Mapper; +import org.apache.mahout.common.RandomUtils; +import org.apache.mahout.math.DenseVector; +import org.apache.mahout.math.MatrixSlice; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.VectorWritable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; + +/** + * Run ensemble learning via loading the {@link ModelTrainer} with two {@link TopicModel} instances: + * one from the previous iteration, the other empty. Inference is done on the first, and the + * learning updates are stored in the second, and only emitted at cleanup(). + * <p/> + * In terms of obvious performance improvements still available, the memory footprint in this + * Mapper could be dropped by half if we accumulated model updates onto the model we're using + * for inference, which might also speed up convergence, as we'd be able to take advantage of + * learning <em>during</em> iteration, not just after each one is done. Most likely we don't + * really need to accumulate double values in the model either, floats would most likely be + * sufficient. Between these two, we could squeeze another factor of 4 in memory efficiency. + * <p/> + * In terms of CPU, we're re-learning the p(topic|doc) distribution on every iteration, starting + * from scratch. This is usually only 10 fixed-point iterations per doc, but that's 10x more than + * only 1. To avoid having to do this, we would need to do a map-side join of the unchanging + * corpus with the continually-improving p(topic|doc) matrix, and then emit multiple outputs + * from the mappers to make sure we can do the reduce model averaging as well. Tricky, but + * possibly worth it. + * <p/> + * {@link ModelTrainer} already takes advantage (in maybe the not-nice way) of multi-core + * availability by doing multithreaded learning, see that class for details. 
+ */ +public class CachingCVB0Mapper + extends Mapper<IntWritable, VectorWritable, IntWritable, VectorWritable> { + + private static final Logger log = LoggerFactory.getLogger(CachingCVB0Mapper.class); + + private ModelTrainer modelTrainer; + private TopicModel readModel; + private TopicModel writeModel; + private int maxIters; + private int numTopics; + + protected ModelTrainer getModelTrainer() { + return modelTrainer; + } + + protected int getMaxIters() { + return maxIters; + } + + protected int getNumTopics() { + return numTopics; + } + + @Override + protected void setup(Context context) throws IOException, InterruptedException { + log.info("Retrieving configuration"); + Configuration conf = context.getConfiguration(); + float eta = conf.getFloat(CVB0Driver.TERM_TOPIC_SMOOTHING, Float.NaN); + float alpha = conf.getFloat(CVB0Driver.DOC_TOPIC_SMOOTHING, Float.NaN); + long seed = conf.getLong(CVB0Driver.RANDOM_SEED, 1234L); + numTopics = conf.getInt(CVB0Driver.NUM_TOPICS, -1); + int numTerms = conf.getInt(CVB0Driver.NUM_TERMS, -1); + int numUpdateThreads = conf.getInt(CVB0Driver.NUM_UPDATE_THREADS, 1); + int numTrainThreads = conf.getInt(CVB0Driver.NUM_TRAIN_THREADS, 4); + maxIters = conf.getInt(CVB0Driver.MAX_ITERATIONS_PER_DOC, 10); + float modelWeight = conf.getFloat(CVB0Driver.MODEL_WEIGHT, 1.0f); + + log.info("Initializing read model"); + Path[] modelPaths = CVB0Driver.getModelPaths(conf); + if (modelPaths != null && modelPaths.length > 0) { + readModel = new TopicModel(conf, eta, alpha, null, numUpdateThreads, modelWeight, modelPaths); + } else { + log.info("No model files found"); + readModel = new TopicModel(numTopics, numTerms, eta, alpha, RandomUtils.getRandom(seed), null, + numTrainThreads, modelWeight); + } + + log.info("Initializing write model"); + writeModel = modelWeight == 1 + ? new TopicModel(numTopics, numTerms, eta, alpha, null, numUpdateThreads) + : readModel; + + log.info("Initializing model trainer"); + modelTrainer = new ModelTrainer(readModel, writeModel, numTrainThreads, numTopics, numTerms); + modelTrainer.start(); + } + + @Override + public void map(IntWritable docId, VectorWritable document, Context context) + throws IOException, InterruptedException { + /* where to get docTopics? */ + Vector topicVector = new DenseVector(numTopics).assign(1.0 / numTopics); + modelTrainer.train(document.get(), topicVector, true, maxIters); + } + + @Override + protected void cleanup(Context context) throws IOException, InterruptedException { + log.info("Stopping model trainer"); + modelTrainer.stop(); + + log.info("Writing model"); + TopicModel readFrom = modelTrainer.getReadModel(); + for (MatrixSlice topic : readFrom) { + context.write(new IntWritable(topic.index()), new VectorWritable(topic.vector())); + } + readModel.stop(); + writeModel.stop(); + } +}
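The mapper above only emits its accumulated (topic, term-counts) rows at cleanup(), so a full CVB0 training pass also needs a reducer that sums those rows across all map tasks, plus a driver that wires the job together. The following is a minimal sketch of that wiring, not part of this commit, under stated assumptions: the CVB0Driver.* constants read in setup() are plain String configuration keys, the names Cvb0IterationSketch, TopicSumReducer and buildIterationJob (and the parameter values shown) are hypothetical, and the real CVB0Driver does considerably more than this, distributing model paths to the mappers, running perplexity jobs and checking convergence.

package org.apache.mahout.clustering.lda.cvb;

// Illustrative sketch only, not part of this commit. Assumes the CVB0Driver.*
// constants referenced by CachingCVB0Mapper are String configuration keys;
// TopicSumReducer and buildIterationJob are hypothetical names.
import java.io.IOException;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.Reducer;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat;
import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat;
import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;

public final class Cvb0IterationSketch {

  /** Sums the rows emitted by all mappers for one topic id into a single model row. */
  public static class TopicSumReducer
      extends Reducer<IntWritable, VectorWritable, IntWritable, VectorWritable> {
    @Override
    protected void reduce(IntWritable topic, Iterable<VectorWritable> rows, Context context)
        throws IOException, InterruptedException {
      Vector sum = null;
      for (VectorWritable row : rows) {
        sum = sum == null ? row.get().clone() : sum.plus(row.get());
      }
      context.write(topic, new VectorWritable(sum));
    }
  }

  /** Configures one training pass: sequence-file corpus in, summed topic rows (next model) out. */
  public static Job buildIterationJob(Configuration conf, Path corpus, Path nextModel)
      throws IOException {
    conf.setInt(CVB0Driver.NUM_TOPICS, 20);             // illustrative values only
    conf.setInt(CVB0Driver.NUM_TERMS, 50000);
    conf.setFloat(CVB0Driver.DOC_TOPIC_SMOOTHING, 0.1f);
    conf.setFloat(CVB0Driver.TERM_TOPIC_SMOOTHING, 0.1f);
    Job job = Job.getInstance(conf, "CVB0 iteration (sketch)");
    job.setJarByClass(Cvb0IterationSketch.class);
    job.setMapperClass(CachingCVB0Mapper.class);
    job.setReducerClass(TopicSumReducer.class);
    job.setMapOutputKeyClass(IntWritable.class);
    job.setMapOutputValueClass(VectorWritable.class);
    job.setOutputKeyClass(IntWritable.class);
    job.setOutputValueClass(VectorWritable.class);
    job.setInputFormatClass(SequenceFileInputFormat.class);
    job.setOutputFormatClass(SequenceFileOutputFormat.class);
    FileInputFormat.addInputPath(job, corpus);
    FileOutputFormat.setOutputPath(job, nextModel);
    return job;
  }
}

Summing (rather than averaging) the emitted rows is enough here because TopicModel, shown further down in this commit, derives p(term | topic) by dividing each row by its sum, so any constant scale factor on a row cancels out.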
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/CachingCVB0PerplexityMapper.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/CachingCVB0PerplexityMapper.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/CachingCVB0PerplexityMapper.java new file mode 100644 index 0000000..da77baf --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/CachingCVB0PerplexityMapper.java @@ -0,0 +1,108 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.mahout.clustering.lda.cvb; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.DoubleWritable; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.mapreduce.Mapper; +import org.apache.mahout.common.MemoryUtil; +import org.apache.mahout.common.RandomUtils; +import org.apache.mahout.math.DenseVector; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.VectorWritable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.Random; + +public class CachingCVB0PerplexityMapper extends + Mapper<IntWritable, VectorWritable, DoubleWritable, DoubleWritable> { + /** + * Hadoop counters for {@link CachingCVB0PerplexityMapper}, to aid in debugging. 
+ */ + public enum Counters { + SAMPLED_DOCUMENTS + } + + private static final Logger log = LoggerFactory.getLogger(CachingCVB0PerplexityMapper.class); + + private ModelTrainer modelTrainer; + private TopicModel readModel; + private int maxIters; + private int numTopics; + private float testFraction; + private Random random; + private Vector topicVector; + private final DoubleWritable outKey = new DoubleWritable(); + private final DoubleWritable outValue = new DoubleWritable(); + + @Override + protected void setup(Context context) throws IOException, InterruptedException { + MemoryUtil.startMemoryLogger(5000); + + log.info("Retrieving configuration"); + Configuration conf = context.getConfiguration(); + float eta = conf.getFloat(CVB0Driver.TERM_TOPIC_SMOOTHING, Float.NaN); + float alpha = conf.getFloat(CVB0Driver.DOC_TOPIC_SMOOTHING, Float.NaN); + long seed = conf.getLong(CVB0Driver.RANDOM_SEED, 1234L); + random = RandomUtils.getRandom(seed); + numTopics = conf.getInt(CVB0Driver.NUM_TOPICS, -1); + int numTerms = conf.getInt(CVB0Driver.NUM_TERMS, -1); + int numUpdateThreads = conf.getInt(CVB0Driver.NUM_UPDATE_THREADS, 1); + int numTrainThreads = conf.getInt(CVB0Driver.NUM_TRAIN_THREADS, 4); + maxIters = conf.getInt(CVB0Driver.MAX_ITERATIONS_PER_DOC, 10); + float modelWeight = conf.getFloat(CVB0Driver.MODEL_WEIGHT, 1.0f); + testFraction = conf.getFloat(CVB0Driver.TEST_SET_FRACTION, 0.1f); + + log.info("Initializing read model"); + Path[] modelPaths = CVB0Driver.getModelPaths(conf); + if (modelPaths != null && modelPaths.length > 0) { + readModel = new TopicModel(conf, eta, alpha, null, numUpdateThreads, modelWeight, modelPaths); + } else { + log.info("No model files found"); + readModel = new TopicModel(numTopics, numTerms, eta, alpha, RandomUtils.getRandom(seed), null, + numTrainThreads, modelWeight); + } + + log.info("Initializing model trainer"); + modelTrainer = new ModelTrainer(readModel, null, numTrainThreads, numTopics, numTerms); + + log.info("Initializing topic vector"); + topicVector = new DenseVector(new double[numTopics]); + } + + @Override + protected void cleanup(Context context) throws IOException, InterruptedException { + readModel.stop(); + MemoryUtil.stopMemoryLogger(); + } + + @Override + public void map(IntWritable docId, VectorWritable document, Context context) + throws IOException, InterruptedException { + if (testFraction < 1.0f && random.nextFloat() >= testFraction) { + return; + } + context.getCounter(Counters.SAMPLED_DOCUMENTS).increment(1); + outKey.set(document.get().norm(1)); + outValue.set(modelTrainer.calculatePerplexity(document.get(), topicVector.assign(1.0 / numTopics), maxIters)); + context.write(outKey, outValue); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/InMemoryCollapsedVariationalBayes0.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/InMemoryCollapsedVariationalBayes0.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/InMemoryCollapsedVariationalBayes0.java new file mode 100644 index 0000000..d7d09c5 --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/InMemoryCollapsedVariationalBayes0.java @@ -0,0 +1,492 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.mahout.clustering.lda.cvb; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import org.apache.commons.cli2.CommandLine; +import org.apache.commons.cli2.Group; +import org.apache.commons.cli2.Option; +import org.apache.commons.cli2.OptionException; +import org.apache.commons.cli2.builder.ArgumentBuilder; +import org.apache.commons.cli2.builder.DefaultOptionBuilder; +import org.apache.commons.cli2.builder.GroupBuilder; +import org.apache.commons.cli2.commandline.Parser; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FileStatus; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.Writable; +import org.apache.hadoop.util.ToolRunner; +import org.apache.mahout.common.AbstractJob; +import org.apache.mahout.common.CommandLineUtil; +import org.apache.mahout.common.Pair; +import org.apache.mahout.common.RandomUtils; +import org.apache.mahout.common.commandline.DefaultOptionCreator; +import org.apache.mahout.common.iterator.sequencefile.PathFilters; +import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable; +import org.apache.mahout.math.DenseMatrix; +import org.apache.mahout.math.DenseVector; +import org.apache.mahout.math.DistributedRowMatrixWriter; +import org.apache.mahout.math.Matrix; +import org.apache.mahout.math.NamedVector; +import org.apache.mahout.math.SparseRowMatrix; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.VectorWritable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Runs the same algorithm as {@link CVB0Driver}, but sequentially, in memory. Memory requirements + * are currently: the entire corpus is read into RAM, two copies of the model (each of size + * numTerms * numTopics), and another matrix of size numDocs * numTopics is held in memory + * (to store p(topic|doc) for all docs). + * + * But if all this fits in memory, this can be significantly faster than an iterative MR job. 
+ */ +public class InMemoryCollapsedVariationalBayes0 extends AbstractJob { + + private static final Logger log = LoggerFactory.getLogger(InMemoryCollapsedVariationalBayes0.class); + + private int numTopics; + private int numTerms; + private int numDocuments; + private double alpha; + private double eta; + //private int minDfCt; + //private double maxDfPct; + private boolean verbose = false; + private String[] terms; // of length numTerms; + private Matrix corpusWeights; // length numDocs; + private double totalCorpusWeight; + private double initialModelCorpusFraction; + private Matrix docTopicCounts; + private int numTrainingThreads; + private int numUpdatingThreads; + private ModelTrainer modelTrainer; + + private InMemoryCollapsedVariationalBayes0() { + // only for main usage + } + + public void setVerbose(boolean verbose) { + this.verbose = verbose; + } + + public InMemoryCollapsedVariationalBayes0(Matrix corpus, + String[] terms, + int numTopics, + double alpha, + double eta, + int numTrainingThreads, + int numUpdatingThreads, + double modelCorpusFraction) { + //this.seed = seed; + this.numTopics = numTopics; + this.alpha = alpha; + this.eta = eta; + //this.minDfCt = 0; + //this.maxDfPct = 1.0f; + corpusWeights = corpus; + numDocuments = corpus.numRows(); + this.terms = terms; + this.initialModelCorpusFraction = modelCorpusFraction; + numTerms = terms != null ? terms.length : corpus.numCols(); + Map<String, Integer> termIdMap = new HashMap<>(); + if (terms != null) { + for (int t = 0; t < terms.length; t++) { + termIdMap.put(terms[t], t); + } + } + this.numTrainingThreads = numTrainingThreads; + this.numUpdatingThreads = numUpdatingThreads; + postInitCorpus(); + initializeModel(); + } + + private void postInitCorpus() { + totalCorpusWeight = 0; + int numNonZero = 0; + for (int i = 0; i < numDocuments; i++) { + Vector v = corpusWeights.viewRow(i); + double norm; + if (v != null && (norm = v.norm(1)) != 0) { + numNonZero += v.getNumNondefaultElements(); + totalCorpusWeight += norm; + } + } + String s = "Initializing corpus with %d docs, %d terms, %d nonzero entries, total termWeight %f"; + log.info(String.format(s, numDocuments, numTerms, numNonZero, totalCorpusWeight)); + } + + private void initializeModel() { + TopicModel topicModel = new TopicModel(numTopics, numTerms, eta, alpha, RandomUtils.getRandom(), terms, + numUpdatingThreads, initialModelCorpusFraction == 0 ? 1 : initialModelCorpusFraction * totalCorpusWeight); + topicModel.setConf(getConf()); + + TopicModel updatedModel = initialModelCorpusFraction == 0 + ? new TopicModel(numTopics, numTerms, eta, alpha, null, terms, numUpdatingThreads, 1) + : topicModel; + updatedModel.setConf(getConf()); + docTopicCounts = new DenseMatrix(numDocuments, numTopics); + docTopicCounts.assign(1.0 / numTopics); + modelTrainer = new ModelTrainer(topicModel, updatedModel, numTrainingThreads, numTopics, numTerms); + } + + /* + private void inferDocuments(double convergence, int maxIter, boolean recalculate) { + for (int docId = 0; docId < corpusWeights.numRows() ; docId++) { + Vector inferredDocument = topicModel.infer(corpusWeights.viewRow(docId), + docTopicCounts.viewRow(docId)); + // do what now? 
+ } + } + */ + + public void trainDocuments() { + trainDocuments(0); + } + + public void trainDocuments(double testFraction) { + long start = System.nanoTime(); + modelTrainer.start(); + for (int docId = 0; docId < corpusWeights.numRows(); docId++) { + if (testFraction == 0 || docId % (1 / testFraction) != 0) { + Vector docTopics = new DenseVector(numTopics).assign(1.0 / numTopics); // docTopicCounts.getRow(docId) + modelTrainer.trainSync(corpusWeights.viewRow(docId), docTopics , true, 10); + } + } + modelTrainer.stop(); + logTime("train documents", System.nanoTime() - start); + } + + /* + private double error(int docId) { + Vector docTermCounts = corpusWeights.viewRow(docId); + if (docTermCounts == null) { + return 0; + } else { + Vector expectedDocTermCounts = + topicModel.infer(corpusWeights.viewRow(docId), docTopicCounts.viewRow(docId)); + double expectedNorm = expectedDocTermCounts.norm(1); + return expectedDocTermCounts.times(docTermCounts.norm(1)/expectedNorm) + .minus(docTermCounts).norm(1); + } + } + + private double error() { + long time = System.nanoTime(); + double error = 0; + for (int docId = 0; docId < numDocuments; docId++) { + error += error(docId); + } + logTime("error calculation", System.nanoTime() - time); + return error / totalCorpusWeight; + } + */ + + public double iterateUntilConvergence(double minFractionalErrorChange, + int maxIterations, int minIter) { + return iterateUntilConvergence(minFractionalErrorChange, maxIterations, minIter, 0); + } + + public double iterateUntilConvergence(double minFractionalErrorChange, + int maxIterations, int minIter, double testFraction) { + int iter = 0; + double oldPerplexity = 0; + while (iter < minIter) { + trainDocuments(testFraction); + if (verbose) { + log.info("model after: {}: {}", iter, modelTrainer.getReadModel()); + } + log.info("iteration {} complete", iter); + oldPerplexity = modelTrainer.calculatePerplexity(corpusWeights, docTopicCounts, + testFraction); + log.info("{} = perplexity", oldPerplexity); + iter++; + } + double newPerplexity = 0; + double fractionalChange = Double.MAX_VALUE; + while (iter < maxIterations && fractionalChange > minFractionalErrorChange) { + trainDocuments(); + if (verbose) { + log.info("model after: {}: {}", iter, modelTrainer.getReadModel()); + } + newPerplexity = modelTrainer.calculatePerplexity(corpusWeights, docTopicCounts, + testFraction); + log.info("{} = perplexity", newPerplexity); + iter++; + fractionalChange = Math.abs(newPerplexity - oldPerplexity) / oldPerplexity; + log.info("{} = fractionalChange", fractionalChange); + oldPerplexity = newPerplexity; + } + if (iter < maxIterations) { + log.info(String.format("Converged! 
fractional error change: %f, error %f", + fractionalChange, newPerplexity)); + } else { + log.info(String.format("Reached max iteration count (%d), fractional error change: %f, error: %f", + maxIterations, fractionalChange, newPerplexity)); + } + return newPerplexity; + } + + public void writeModel(Path outputPath) throws IOException { + modelTrainer.persist(outputPath); + } + + private static void logTime(String label, long nanos) { + log.info("{} time: {}ms", label, nanos / 1.0e6); + } + + public static int main2(String[] args, Configuration conf) throws Exception { + DefaultOptionBuilder obuilder = new DefaultOptionBuilder(); + ArgumentBuilder abuilder = new ArgumentBuilder(); + GroupBuilder gbuilder = new GroupBuilder(); + + Option helpOpt = DefaultOptionCreator.helpOption(); + + Option inputDirOpt = obuilder.withLongName("input").withRequired(true).withArgument( + abuilder.withName("input").withMinimum(1).withMaximum(1).create()).withDescription( + "The Directory on HDFS containing the collapsed, properly formatted files having " + + "one doc per line").withShortName("i").create(); + + Option dictOpt = obuilder.withLongName("dictionary").withRequired(false).withArgument( + abuilder.withName("dictionary").withMinimum(1).withMaximum(1).create()).withDescription( + "The path to the term-dictionary format is ... ").withShortName("d").create(); + + Option dfsOpt = obuilder.withLongName("dfs").withRequired(false).withArgument( + abuilder.withName("dfs").withMinimum(1).withMaximum(1).create()).withDescription( + "HDFS namenode URI").withShortName("dfs").create(); + + Option numTopicsOpt = obuilder.withLongName("numTopics").withRequired(true).withArgument(abuilder + .withName("numTopics").withMinimum(1).withMaximum(1) + .create()).withDescription("Number of topics to learn").withShortName("top").create(); + + Option outputTopicFileOpt = obuilder.withLongName("topicOutputFile").withRequired(true).withArgument( + abuilder.withName("topicOutputFile").withMinimum(1).withMaximum(1).create()) + .withDescription("File to write out p(term | topic)").withShortName("to").create(); + + Option outputDocFileOpt = obuilder.withLongName("docOutputFile").withRequired(true).withArgument( + abuilder.withName("docOutputFile").withMinimum(1).withMaximum(1).create()) + .withDescription("File to write out p(topic | docid)").withShortName("do").create(); + + Option alphaOpt = obuilder.withLongName("alpha").withRequired(false).withArgument(abuilder + .withName("alpha").withMinimum(1).withMaximum(1).withDefault("0.1").create()) + .withDescription("Smoothing parameter for p(topic | document) prior").withShortName("a").create(); + + Option etaOpt = obuilder.withLongName("eta").withRequired(false).withArgument(abuilder + .withName("eta").withMinimum(1).withMaximum(1).withDefault("0.1").create()) + .withDescription("Smoothing parameter for p(term | topic)").withShortName("e").create(); + + Option maxIterOpt = obuilder.withLongName("maxIterations").withRequired(false).withArgument(abuilder + .withName("maxIterations").withMinimum(1).withMaximum(1).withDefault("10").create()) + .withDescription("Maximum number of training passes").withShortName("m").create(); + + Option modelCorpusFractionOption = obuilder.withLongName("modelCorpusFraction") + .withRequired(false).withArgument(abuilder.withName("modelCorpusFraction").withMinimum(1) + .withMaximum(1).withDefault("0.0").create()).withShortName("mcf") + .withDescription("For online updates, initial value of |model|/|corpus|").create(); + + Option burnInOpt = 
obuilder.withLongName("burnInIterations").withRequired(false).withArgument(abuilder + .withName("burnInIterations").withMinimum(1).withMaximum(1).withDefault("5").create()) + .withDescription("Minimum number of iterations").withShortName("b").create(); + + Option convergenceOpt = obuilder.withLongName("convergence").withRequired(false).withArgument(abuilder + .withName("convergence").withMinimum(1).withMaximum(1).withDefault("0.0").create()) + .withDescription("Fractional rate of perplexity to consider convergence").withShortName("c").create(); + + Option reInferDocTopicsOpt = obuilder.withLongName("reInferDocTopics").withRequired(false) + .withArgument(abuilder.withName("reInferDocTopics").withMinimum(1).withMaximum(1) + .withDefault("no").create()) + .withDescription("re-infer p(topic | doc) : [no | randstart | continue]") + .withShortName("rdt").create(); + + Option numTrainThreadsOpt = obuilder.withLongName("numTrainThreads").withRequired(false) + .withArgument(abuilder.withName("numTrainThreads").withMinimum(1).withMaximum(1) + .withDefault("1").create()) + .withDescription("number of threads to train with") + .withShortName("ntt").create(); + + Option numUpdateThreadsOpt = obuilder.withLongName("numUpdateThreads").withRequired(false) + .withArgument(abuilder.withName("numUpdateThreads").withMinimum(1).withMaximum(1) + .withDefault("1").create()) + .withDescription("number of threads to update the model with") + .withShortName("nut").create(); + + Option verboseOpt = obuilder.withLongName("verbose").withRequired(false) + .withArgument(abuilder.withName("verbose").withMinimum(1).withMaximum(1) + .withDefault("false").create()) + .withDescription("print verbose information, like top-terms in each topic, during iteration") + .withShortName("v").create(); + + Group group = gbuilder.withName("Options").withOption(inputDirOpt).withOption(numTopicsOpt) + .withOption(alphaOpt).withOption(etaOpt) + .withOption(maxIterOpt).withOption(burnInOpt).withOption(convergenceOpt) + .withOption(dictOpt).withOption(reInferDocTopicsOpt) + .withOption(outputDocFileOpt).withOption(outputTopicFileOpt).withOption(dfsOpt) + .withOption(numTrainThreadsOpt).withOption(numUpdateThreadsOpt) + .withOption(modelCorpusFractionOption).withOption(verboseOpt).create(); + + try { + Parser parser = new Parser(); + + parser.setGroup(group); + parser.setHelpOption(helpOpt); + CommandLine cmdLine = parser.parse(args); + if (cmdLine.hasOption(helpOpt)) { + CommandLineUtil.printHelp(group); + return -1; + } + + String inputDirString = (String) cmdLine.getValue(inputDirOpt); + String dictDirString = cmdLine.hasOption(dictOpt) ? 
(String)cmdLine.getValue(dictOpt) : null; + int numTopics = Integer.parseInt((String) cmdLine.getValue(numTopicsOpt)); + double alpha = Double.parseDouble((String)cmdLine.getValue(alphaOpt)); + double eta = Double.parseDouble((String)cmdLine.getValue(etaOpt)); + int maxIterations = Integer.parseInt((String)cmdLine.getValue(maxIterOpt)); + int burnInIterations = Integer.parseInt((String)cmdLine.getValue(burnInOpt)); + double minFractionalErrorChange = Double.parseDouble((String) cmdLine.getValue(convergenceOpt)); + int numTrainThreads = Integer.parseInt((String)cmdLine.getValue(numTrainThreadsOpt)); + int numUpdateThreads = Integer.parseInt((String)cmdLine.getValue(numUpdateThreadsOpt)); + String topicOutFile = (String)cmdLine.getValue(outputTopicFileOpt); + String docOutFile = (String)cmdLine.getValue(outputDocFileOpt); + //String reInferDocTopics = (String)cmdLine.getValue(reInferDocTopicsOpt); + boolean verbose = Boolean.parseBoolean((String) cmdLine.getValue(verboseOpt)); + double modelCorpusFraction = Double.parseDouble((String)cmdLine.getValue(modelCorpusFractionOption)); + + long start = System.nanoTime(); + + if (conf.get("fs.default.name") == null) { + String dfsNameNode = (String)cmdLine.getValue(dfsOpt); + conf.set("fs.default.name", dfsNameNode); + } + String[] terms = loadDictionary(dictDirString, conf); + logTime("dictionary loading", System.nanoTime() - start); + start = System.nanoTime(); + Matrix corpus = loadVectors(inputDirString, conf); + logTime("vector seqfile corpus loading", System.nanoTime() - start); + start = System.nanoTime(); + InMemoryCollapsedVariationalBayes0 cvb0 = + new InMemoryCollapsedVariationalBayes0(corpus, terms, numTopics, alpha, eta, + numTrainThreads, numUpdateThreads, modelCorpusFraction); + logTime("cvb0 init", System.nanoTime() - start); + + start = System.nanoTime(); + cvb0.setVerbose(verbose); + cvb0.iterateUntilConvergence(minFractionalErrorChange, maxIterations, burnInIterations); + logTime("total training time", System.nanoTime() - start); + + /* + if ("randstart".equalsIgnoreCase(reInferDocTopics)) { + cvb0.inferDocuments(0.0, 100, true); + } else if ("continue".equalsIgnoreCase(reInferDocTopics)) { + cvb0.inferDocuments(0.0, 100, false); + } + */ + + start = System.nanoTime(); + cvb0.writeModel(new Path(topicOutFile)); + DistributedRowMatrixWriter.write(new Path(docOutFile), conf, cvb0.docTopicCounts); + logTime("printTopics", System.nanoTime() - start); + } catch (OptionException e) { + log.error("Error while parsing options", e); + CommandLineUtil.printHelp(group); + } + return 0; + } + + private static String[] loadDictionary(String dictionaryPath, Configuration conf) { + if (dictionaryPath == null) { + return null; + } + Path dictionaryFile = new Path(dictionaryPath); + List<Pair<Integer, String>> termList = new ArrayList<>(); + int maxTermId = 0; + // key is word value is id + for (Pair<Writable, IntWritable> record + : new SequenceFileIterable<Writable, IntWritable>(dictionaryFile, true, conf)) { + termList.add(new Pair<>(record.getSecond().get(), + record.getFirst().toString())); + maxTermId = Math.max(maxTermId, record.getSecond().get()); + } + String[] terms = new String[maxTermId + 1]; + for (Pair<Integer, String> pair : termList) { + terms[pair.getFirst()] = pair.getSecond(); + } + return terms; + } + + @Override + public Configuration getConf() { + return super.getConf(); + } + + private static Matrix loadVectors(String vectorPathString, Configuration conf) + throws IOException { + Path vectorPath = new Path(vectorPathString); 
+ FileSystem fs = vectorPath.getFileSystem(conf); + List<Path> subPaths = new ArrayList<>(); + if (fs.isFile(vectorPath)) { + subPaths.add(vectorPath); + } else { + for (FileStatus fileStatus : fs.listStatus(vectorPath, PathFilters.logsCRCFilter())) { + subPaths.add(fileStatus.getPath()); + } + } + List<Pair<Integer, Vector>> rowList = new ArrayList<>(); + int numRows = Integer.MIN_VALUE; + int numCols = -1; + boolean sequentialAccess = false; + for (Path subPath : subPaths) { + for (Pair<IntWritable, VectorWritable> record + : new SequenceFileIterable<IntWritable, VectorWritable>(subPath, true, conf)) { + int id = record.getFirst().get(); + Vector vector = record.getSecond().get(); + if (vector instanceof NamedVector) { + vector = ((NamedVector)vector).getDelegate(); + } + if (numCols < 0) { + numCols = vector.size(); + sequentialAccess = vector.isSequentialAccess(); + } + rowList.add(Pair.of(id, vector)); + numRows = Math.max(numRows, id); + } + } + numRows++; + Vector[] rowVectors = new Vector[numRows]; + for (Pair<Integer, Vector> pair : rowList) { + rowVectors[pair.getFirst()] = pair.getSecond(); + } + return new SparseRowMatrix(numRows, numCols, rowVectors, true, !sequentialAccess); + + } + + @Override + public int run(String[] strings) throws Exception { + return main2(strings, getConf()); + } + + public static void main(String[] args) throws Exception { + ToolRunner.run(new InMemoryCollapsedVariationalBayes0(), args); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/ModelTrainer.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/ModelTrainer.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/ModelTrainer.java new file mode 100644 index 0000000..c3f2bc0 --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/ModelTrainer.java @@ -0,0 +1,301 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.mahout.clustering.lda.cvb; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ArrayBlockingQueue; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.Callable; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; + +import org.apache.hadoop.fs.Path; +import org.apache.mahout.math.Matrix; +import org.apache.mahout.math.MatrixSlice; +import org.apache.mahout.math.SparseRowMatrix; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.VectorIterable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Multithreaded LDA model trainer class, which primarily operates by running a "map/reduce" + * operation, all in memory locally (ie not a hadoop job!) : the "map" operation is to take + * the "read-only" {@link TopicModel} and use it to iteratively learn the p(topic|term, doc) + * distribution for documents (this can be done in parallel across many documents, as the + * "read-only" model is, well, read-only. Then the outputs of this are "reduced" onto the + * "write" model, and these updates are not parallelizable in the same way: individual + * documents can't be added to the same entries in different threads at the same time, but + * updates across many topics to the same term from the same document can be done in parallel, + * so they are. + * + * Because computation is done asynchronously, when iteration is done, it's important to call + * the stop() method, which blocks until work is complete. + * + * Setting the read model and the write model to be the same object may not quite work yet, + * on account of parallelism badness. + */ +public class ModelTrainer { + + private static final Logger log = LoggerFactory.getLogger(ModelTrainer.class); + + private final int numTopics; + private final int numTerms; + private TopicModel readModel; + private TopicModel writeModel; + private ThreadPoolExecutor threadPool; + private BlockingQueue<Runnable> workQueue; + private final int numTrainThreads; + private final boolean isReadWrite; + + public ModelTrainer(TopicModel initialReadModel, TopicModel initialWriteModel, + int numTrainThreads, int numTopics, int numTerms) { + this.readModel = initialReadModel; + this.writeModel = initialWriteModel; + this.numTrainThreads = numTrainThreads; + this.numTopics = numTopics; + this.numTerms = numTerms; + isReadWrite = initialReadModel == initialWriteModel; + } + + /** + * WARNING: this constructor may not lead to good behavior. What should be verified is that + * the model updating process does not conflict with model reading. It might work, but then + * again, it might not! 
+ * @param model to be used for both reading (inference) and accumulating (learning) + * @param numTrainThreads + * @param numTopics + * @param numTerms + */ + public ModelTrainer(TopicModel model, int numTrainThreads, int numTopics, int numTerms) { + this(model, model, numTrainThreads, numTopics, numTerms); + } + + public TopicModel getReadModel() { + return readModel; + } + + public void start() { + log.info("Starting training threadpool with {} threads", numTrainThreads); + workQueue = new ArrayBlockingQueue<>(numTrainThreads * 10); + threadPool = new ThreadPoolExecutor(numTrainThreads, numTrainThreads, 0, TimeUnit.SECONDS, + workQueue); + threadPool.allowCoreThreadTimeOut(false); + threadPool.prestartAllCoreThreads(); + writeModel.reset(); + } + + public void train(VectorIterable matrix, VectorIterable docTopicCounts) { + train(matrix, docTopicCounts, 1); + } + + public double calculatePerplexity(VectorIterable matrix, VectorIterable docTopicCounts) { + return calculatePerplexity(matrix, docTopicCounts, 0); + } + + public double calculatePerplexity(VectorIterable matrix, VectorIterable docTopicCounts, + double testFraction) { + Iterator<MatrixSlice> docIterator = matrix.iterator(); + Iterator<MatrixSlice> docTopicIterator = docTopicCounts.iterator(); + double perplexity = 0; + double matrixNorm = 0; + while (docIterator.hasNext() && docTopicIterator.hasNext()) { + MatrixSlice docSlice = docIterator.next(); + MatrixSlice topicSlice = docTopicIterator.next(); + int docId = docSlice.index(); + Vector document = docSlice.vector(); + Vector topicDist = topicSlice.vector(); + if (testFraction == 0 || docId % (1 / testFraction) == 0) { + trainSync(document, topicDist, false, 10); + perplexity += readModel.perplexity(document, topicDist); + matrixNorm += document.norm(1); + } + } + return perplexity / matrixNorm; + } + + public void train(VectorIterable matrix, VectorIterable docTopicCounts, int numDocTopicIters) { + start(); + Iterator<MatrixSlice> docIterator = matrix.iterator(); + Iterator<MatrixSlice> docTopicIterator = docTopicCounts.iterator(); + long startTime = System.nanoTime(); + int i = 0; + double[] times = new double[100]; + Map<Vector, Vector> batch = new HashMap<>(); + int numTokensInBatch = 0; + long batchStart = System.nanoTime(); + while (docIterator.hasNext() && docTopicIterator.hasNext()) { + i++; + Vector document = docIterator.next().vector(); + Vector topicDist = docTopicIterator.next().vector(); + if (isReadWrite) { + if (batch.size() < numTrainThreads) { + batch.put(document, topicDist); + if (log.isDebugEnabled()) { + numTokensInBatch += document.getNumNondefaultElements(); + } + } else { + batchTrain(batch, true, numDocTopicIters); + long time = System.nanoTime(); + log.debug("trained {} docs with {} tokens, start time {}, end time {}", + numTrainThreads, numTokensInBatch, batchStart, time); + batchStart = time; + numTokensInBatch = 0; + } + } else { + long start = System.nanoTime(); + train(document, topicDist, true, numDocTopicIters); + if (log.isDebugEnabled()) { + times[i % times.length] = + (System.nanoTime() - start) / (1.0e6 * document.getNumNondefaultElements()); + if (i % 100 == 0) { + long time = System.nanoTime() - startTime; + log.debug("trained {} documents in {}ms", i, time / 1.0e6); + if (i % 500 == 0) { + Arrays.sort(times); + log.debug("training took median {}ms per token-instance", times[times.length / 2]); + } + } + } + } + } + stop(); + } + + public void batchTrain(Map<Vector, Vector> batch, boolean update, int numDocTopicsIters) { + while (true) 
{ + try { + List<TrainerRunnable> runnables = new ArrayList<>(); + for (Map.Entry<Vector, Vector> entry : batch.entrySet()) { + runnables.add(new TrainerRunnable(readModel, null, entry.getKey(), + entry.getValue(), new SparseRowMatrix(numTopics, numTerms, true), + numDocTopicsIters)); + } + threadPool.invokeAll(runnables); + if (update) { + for (TrainerRunnable runnable : runnables) { + writeModel.update(runnable.docTopicModel); + } + } + break; + } catch (InterruptedException e) { + log.warn("Interrupted during batch training, retrying!", e); + } + } + } + + public void train(Vector document, Vector docTopicCounts, boolean update, int numDocTopicIters) { + while (true) { + try { + workQueue.put(new TrainerRunnable(readModel, update + ? writeModel + : null, document, docTopicCounts, new SparseRowMatrix(numTopics, numTerms, true), numDocTopicIters)); + return; + } catch (InterruptedException e) { + log.warn("Interrupted waiting to submit document to work queue: {}", document, e); + } + } + } + + public void trainSync(Vector document, Vector docTopicCounts, boolean update, + int numDocTopicIters) { + new TrainerRunnable(readModel, update + ? writeModel + : null, document, docTopicCounts, new SparseRowMatrix(numTopics, numTerms, true), numDocTopicIters).run(); + } + + public double calculatePerplexity(Vector document, Vector docTopicCounts, int numDocTopicIters) { + TrainerRunnable runner = new TrainerRunnable(readModel, null, document, docTopicCounts, + new SparseRowMatrix(numTopics, numTerms, true), numDocTopicIters); + return runner.call(); + } + + public void stop() { + long startTime = System.nanoTime(); + log.info("Initiating stopping of training threadpool"); + try { + threadPool.shutdown(); + if (!threadPool.awaitTermination(60, TimeUnit.SECONDS)) { + log.warn("Threadpool timed out on await termination - jobs still running!"); + } + long newTime = System.nanoTime(); + log.info("threadpool took: {}ms", (newTime - startTime) / 1.0e6); + startTime = newTime; + readModel.stop(); + newTime = System.nanoTime(); + log.info("readModel.stop() took {}ms", (newTime - startTime) / 1.0e6); + startTime = newTime; + writeModel.stop(); + newTime = System.nanoTime(); + log.info("writeModel.stop() took {}ms", (newTime - startTime) / 1.0e6); + TopicModel tmpModel = writeModel; + writeModel = readModel; + readModel = tmpModel; + } catch (InterruptedException e) { + log.error("Interrupted shutting down!", e); + } + } + + public void persist(Path outputPath) throws IOException { + readModel.persist(outputPath, true); + } + + private static final class TrainerRunnable implements Runnable, Callable<Double> { + private final TopicModel readModel; + private final TopicModel writeModel; + private final Vector document; + private final Vector docTopics; + private final Matrix docTopicModel; + private final int numDocTopicIters; + + private TrainerRunnable(TopicModel readModel, TopicModel writeModel, Vector document, + Vector docTopics, Matrix docTopicModel, int numDocTopicIters) { + this.readModel = readModel; + this.writeModel = writeModel; + this.document = document; + this.docTopics = docTopics; + this.docTopicModel = docTopicModel; + this.numDocTopicIters = numDocTopicIters; + } + + @Override + public void run() { + for (int i = 0; i < numDocTopicIters; i++) { + // synchronous read-only call: + readModel.trainDocTopicModel(document, docTopics, docTopicModel); + } + if (writeModel != null) { + // parallel call which is read-only on the docTopicModel, and write-only on the writeModel + // this method does not 
return until all rows of the docTopicModel have been submitted + // to write work queues + writeModel.update(docTopicModel); + } + } + + @Override + public Double call() { + run(); + return readModel.perplexity(document, docTopics); + } + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/TopicModel.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/TopicModel.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/TopicModel.java new file mode 100644 index 0000000..9ba77c1 --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/TopicModel.java @@ -0,0 +1,513 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.mahout.clustering.lda.cvb; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.Comparator; +import java.util.Iterator; +import java.util.List; +import java.util.Random; +import java.util.concurrent.ArrayBlockingQueue; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; + +import org.apache.hadoop.conf.Configurable; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.IntWritable; +import org.apache.mahout.common.Pair; +import org.apache.mahout.common.RandomUtils; +import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable; +import org.apache.mahout.math.DenseMatrix; +import org.apache.mahout.math.DenseVector; +import org.apache.mahout.math.DistributedRowMatrixWriter; +import org.apache.mahout.math.Matrix; +import org.apache.mahout.math.MatrixSlice; +import org.apache.mahout.math.SequentialAccessSparseVector; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.Vector.Element; +import org.apache.mahout.math.VectorWritable; +import org.apache.mahout.math.function.Functions; +import org.apache.mahout.math.stats.Sampler; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Thin wrapper around a {@link Matrix} of counts of occurrences of (topic, term) pairs. Dividing + * {code topicTermCount.viewRow(topic).get(term)} by the sum over the values for all terms in that + * row yields p(term | topic). Instead dividing it by all topic columns for that term yields + * p(topic | term). + * + * Multithreading is enabled for the {@code update(Matrix)} method: this method is async, and + * merely submits the matrix to a work queue. 
When all work has been submitted, + * {@code awaitTermination()} should be called, which will block until updates have been + * accumulated. + */ +public class TopicModel implements Configurable, Iterable<MatrixSlice> { + + private static final Logger log = LoggerFactory.getLogger(TopicModel.class); + + private final String[] dictionary; + private final Matrix topicTermCounts; + private final Vector topicSums; + private final int numTopics; + private final int numTerms; + private final double eta; + private final double alpha; + + private Configuration conf; + + private final Sampler sampler; + private final int numThreads; + private ThreadPoolExecutor threadPool; + private Updater[] updaters; + + public int getNumTerms() { + return numTerms; + } + + public int getNumTopics() { + return numTopics; + } + + public TopicModel(int numTopics, int numTerms, double eta, double alpha, String[] dictionary, + double modelWeight) { + this(numTopics, numTerms, eta, alpha, null, dictionary, 1, modelWeight); + } + + public TopicModel(Configuration conf, double eta, double alpha, + String[] dictionary, int numThreads, double modelWeight, Path... modelpath) throws IOException { + this(loadModel(conf, modelpath), eta, alpha, dictionary, numThreads, modelWeight); + } + + public TopicModel(int numTopics, int numTerms, double eta, double alpha, String[] dictionary, + int numThreads, double modelWeight) { + this(new DenseMatrix(numTopics, numTerms), new DenseVector(numTopics), eta, alpha, dictionary, + numThreads, modelWeight); + } + + public TopicModel(int numTopics, int numTerms, double eta, double alpha, Random random, + String[] dictionary, int numThreads, double modelWeight) { + this(randomMatrix(numTopics, numTerms, random), eta, alpha, dictionary, numThreads, modelWeight); + } + + private TopicModel(Pair<Matrix, Vector> model, double eta, double alpha, String[] dict, + int numThreads, double modelWeight) { + this(model.getFirst(), model.getSecond(), eta, alpha, dict, numThreads, modelWeight); + } + + public TopicModel(Matrix topicTermCounts, Vector topicSums, double eta, double alpha, + String[] dictionary, double modelWeight) { + this(topicTermCounts, topicSums, eta, alpha, dictionary, 1, modelWeight); + } + + public TopicModel(Matrix topicTermCounts, double eta, double alpha, String[] dictionary, + int numThreads, double modelWeight) { + this(topicTermCounts, viewRowSums(topicTermCounts), + eta, alpha, dictionary, numThreads, modelWeight); + } + + public TopicModel(Matrix topicTermCounts, Vector topicSums, double eta, double alpha, + String[] dictionary, int numThreads, double modelWeight) { + this.dictionary = dictionary; + this.topicTermCounts = topicTermCounts; + this.topicSums = topicSums; + this.numTopics = topicSums.size(); + this.numTerms = topicTermCounts.numCols(); + this.eta = eta; + this.alpha = alpha; + this.sampler = new Sampler(RandomUtils.getRandom()); + this.numThreads = numThreads; + if (modelWeight != 1) { + topicSums.assign(Functions.mult(modelWeight)); + for (int x = 0; x < numTopics; x++) { + topicTermCounts.viewRow(x).assign(Functions.mult(modelWeight)); + } + } + initializeThreadPool(); + } + + private static Vector viewRowSums(Matrix m) { + Vector v = new DenseVector(m.numRows()); + for (MatrixSlice slice : m) { + v.set(slice.index(), slice.vector().norm(1)); + } + return v; + } + + private synchronized void initializeThreadPool() { + if (threadPool != null) { + threadPool.shutdown(); + try { + threadPool.awaitTermination(100, TimeUnit.SECONDS); + } catch (InterruptedException 
e) { + log.error("Could not terminate all threads for TopicModel in time.", e); + } + } + threadPool = new ThreadPoolExecutor(numThreads, numThreads, 0, TimeUnit.SECONDS, + new ArrayBlockingQueue<Runnable>(numThreads * 10)); + threadPool.allowCoreThreadTimeOut(false); + updaters = new Updater[numThreads]; + for (int i = 0; i < numThreads; i++) { + updaters[i] = new Updater(); + threadPool.submit(updaters[i]); + } + } + + Matrix topicTermCounts() { + return topicTermCounts; + } + + @Override + public Iterator<MatrixSlice> iterator() { + return topicTermCounts.iterateAll(); + } + + public Vector topicSums() { + return topicSums; + } + + private static Pair<Matrix,Vector> randomMatrix(int numTopics, int numTerms, Random random) { + Matrix topicTermCounts = new DenseMatrix(numTopics, numTerms); + Vector topicSums = new DenseVector(numTopics); + if (random != null) { + for (int x = 0; x < numTopics; x++) { + for (int term = 0; term < numTerms; term++) { + topicTermCounts.viewRow(x).set(term, random.nextDouble()); + } + } + } + for (int x = 0; x < numTopics; x++) { + topicSums.set(x, random == null ? 1.0 : topicTermCounts.viewRow(x).norm(1)); + } + return Pair.of(topicTermCounts, topicSums); + } + + public static Pair<Matrix, Vector> loadModel(Configuration conf, Path... modelPaths) + throws IOException { + int numTopics = -1; + int numTerms = -1; + List<Pair<Integer, Vector>> rows = new ArrayList<>(); + for (Path modelPath : modelPaths) { + for (Pair<IntWritable, VectorWritable> row + : new SequenceFileIterable<IntWritable, VectorWritable>(modelPath, true, conf)) { + rows.add(Pair.of(row.getFirst().get(), row.getSecond().get())); + numTopics = Math.max(numTopics, row.getFirst().get()); + if (numTerms < 0) { + numTerms = row.getSecond().get().size(); + } + } + } + if (rows.isEmpty()) { + throw new IOException(Arrays.toString(modelPaths) + " have no vectors in it"); + } + numTopics++; + Matrix model = new DenseMatrix(numTopics, numTerms); + Vector topicSums = new DenseVector(numTopics); + for (Pair<Integer, Vector> pair : rows) { + model.viewRow(pair.getFirst()).assign(pair.getSecond()); + topicSums.set(pair.getFirst(), pair.getSecond().norm(1)); + } + return Pair.of(model, topicSums); + } + + // NOTE: this is purely for debug purposes. It is not performant to "toString()" a real model + @Override + public String toString() { + StringBuilder buf = new StringBuilder(); + for (int x = 0; x < numTopics; x++) { + String v = dictionary != null + ? 
vectorToSortedString(topicTermCounts.viewRow(x).normalize(1), dictionary) + : topicTermCounts.viewRow(x).asFormatString(); + buf.append(v).append('\n'); + } + return buf.toString(); + } + + public int sampleTerm(Vector topicDistribution) { + return sampler.sample(topicTermCounts.viewRow(sampler.sample(topicDistribution))); + } + + public int sampleTerm(int topic) { + return sampler.sample(topicTermCounts.viewRow(topic)); + } + + public synchronized void reset() { + for (int x = 0; x < numTopics; x++) { + topicTermCounts.assignRow(x, new SequentialAccessSparseVector(numTerms)); + } + topicSums.assign(1.0); + if (threadPool.isTerminated()) { + initializeThreadPool(); + } + } + + public synchronized void stop() { + for (Updater updater : updaters) { + updater.shutdown(); + } + threadPool.shutdown(); + try { + if (!threadPool.awaitTermination(60, TimeUnit.SECONDS)) { + log.warn("Threadpool timed out on await termination - jobs still running!"); + } + } catch (InterruptedException e) { + log.error("Interrupted shutting down!", e); + } + } + + public void renormalize() { + for (int x = 0; x < numTopics; x++) { + topicTermCounts.assignRow(x, topicTermCounts.viewRow(x).normalize(1)); + topicSums.assign(1.0); + } + } + + public void trainDocTopicModel(Vector original, Vector topics, Matrix docTopicModel) { + // first calculate p(topic|term,document) for all terms in original, and all topics, + // using p(term|topic) and p(topic|doc) + pTopicGivenTerm(original, topics, docTopicModel); + normalizeByTopic(docTopicModel); + // now multiply, term-by-term, by the document, to get the weighted distribution of + // term-topic pairs from this document. + for (Element e : original.nonZeroes()) { + for (int x = 0; x < numTopics; x++) { + Vector docTopicModelRow = docTopicModel.viewRow(x); + docTopicModelRow.setQuick(e.index(), docTopicModelRow.getQuick(e.index()) * e.get()); + } + } + // now recalculate \(p(topic|doc)\) by summing contributions from all of pTopicGivenTerm + topics.assign(0.0); + for (int x = 0; x < numTopics; x++) { + topics.set(x, docTopicModel.viewRow(x).norm(1)); + } + // now renormalize so that \(sum_x(p(x|doc))\) = 1 + topics.assign(Functions.mult(1 / topics.norm(1))); + } + + public Vector infer(Vector original, Vector docTopics) { + Vector pTerm = original.like(); + for (Element e : original.nonZeroes()) { + int term = e.index(); + // p(a) = sum_x (p(a|x) * p(x|i)) + double pA = 0; + for (int x = 0; x < numTopics; x++) { + pA += (topicTermCounts.viewRow(x).get(term) / topicSums.get(x)) * docTopics.get(x); + } + pTerm.set(term, pA); + } + return pTerm; + } + + public void update(Matrix docTopicCounts) { + for (int x = 0; x < numTopics; x++) { + updaters[x % updaters.length].update(x, docTopicCounts.viewRow(x)); + } + } + + public void updateTopic(int topic, Vector docTopicCounts) { + topicTermCounts.viewRow(topic).assign(docTopicCounts, Functions.PLUS); + topicSums.set(topic, topicSums.get(topic) + docTopicCounts.norm(1)); + } + + public void update(int termId, Vector topicCounts) { + for (int x = 0; x < numTopics; x++) { + Vector v = topicTermCounts.viewRow(x); + v.set(termId, v.get(termId) + topicCounts.get(x)); + } + topicSums.assign(topicCounts, Functions.PLUS); + } + + public void persist(Path outputDir, boolean overwrite) throws IOException { + FileSystem fs = outputDir.getFileSystem(conf); + if (overwrite) { + fs.delete(outputDir, true); // CHECK second arg + } + DistributedRowMatrixWriter.write(outputDir, conf, topicTermCounts); + } + + /** + * Computes {@code \(p(topic x | term 
a, document i)\)} distributions given input document {@code i}. + * {@code \(pTGT[x][a]\)} is the (un-normalized) {@code \(p(x|a,i)\)}, or if docTopics is {@code null}, + * {@code \(p(a|x)\)} (also un-normalized). + * + * @param document doc-term vector encoding {@code \(w(term a|document i)\)}. + * @param docTopics {@code docTopics[x]} is the overall weight of topic {@code x} in given + * document. If {@code null}, a topic weight of {@code 1.0} is used for all topics. + * @param termTopicDist storage for output {@code \(p(x|a,i)\)} distributions. + */ + private void pTopicGivenTerm(Vector document, Vector docTopics, Matrix termTopicDist) { + // for each topic x + for (int x = 0; x < numTopics; x++) { + // get p(topic x | document i), or 1.0 if docTopics is null + double topicWeight = docTopics == null ? 1.0 : docTopics.get(x); + // get w(term a | topic x) + Vector topicTermRow = topicTermCounts.viewRow(x); + // get \sum_a w(term a | topic x) + double topicSum = topicSums.get(x); + // get p(topic x | term a) distribution to update + Vector termTopicRow = termTopicDist.viewRow(x); + + // for each term a in document i with non-zero weight + for (Element e : document.nonZeroes()) { + int termIndex = e.index(); + + // calc un-normalized p(topic x | term a, document i) + double termTopicLikelihood = (topicTermRow.get(termIndex) + eta) * (topicWeight + alpha) + / (topicSum + eta * numTerms); + termTopicRow.set(termIndex, termTopicLikelihood); + } + } + } + + /** + * \(sum_x sum_a (c_ai * log(p(x|i) * p(a|x)))\) + */ + public double perplexity(Vector document, Vector docTopics) { + double perplexity = 0; + double norm = docTopics.norm(1) + (docTopics.size() * alpha); + for (Element e : document.nonZeroes()) { + int term = e.index(); + double prob = 0; + for (int x = 0; x < numTopics; x++) { + double d = (docTopics.get(x) + alpha) / norm; + double p = d * (topicTermCounts.viewRow(x).get(term) + eta) + / (topicSums.get(x) + eta * numTerms); + prob += p; + } + perplexity += e.get() * Math.log(prob); + } + return -perplexity; + } + + private void normalizeByTopic(Matrix perTopicSparseDistributions) { + // then make sure that each of these is properly normalized by topic: sum_x(p(x|t,d)) = 1 + for (Element e : perTopicSparseDistributions.viewRow(0).nonZeroes()) { + int a = e.index(); + double sum = 0; + for (int x = 0; x < numTopics; x++) { + sum += perTopicSparseDistributions.viewRow(x).get(a); + } + for (int x = 0; x < numTopics; x++) { + perTopicSparseDistributions.viewRow(x).set(a, + perTopicSparseDistributions.viewRow(x).get(a) / sum); + } + } + } + + public static String vectorToSortedString(Vector vector, String[] dictionary) { + List<Pair<String,Double>> vectorValues = new ArrayList<>(vector.getNumNondefaultElements()); + for (Element e : vector.nonZeroes()) { + vectorValues.add(Pair.of(dictionary != null ? 
dictionary[e.index()] : String.valueOf(e.index()), + e.get())); + } + Collections.sort(vectorValues, new Comparator<Pair<String, Double>>() { + @Override public int compare(Pair<String, Double> x, Pair<String, Double> y) { + return y.getSecond().compareTo(x.getSecond()); + } + }); + Iterator<Pair<String,Double>> listIt = vectorValues.iterator(); + StringBuilder bldr = new StringBuilder(2048); + bldr.append('{'); + int i = 0; + while (listIt.hasNext() && i < 25) { + i++; + Pair<String,Double> p = listIt.next(); + bldr.append(p.getFirst()); + bldr.append(':'); + bldr.append(p.getSecond()); + bldr.append(','); + } + if (bldr.length() > 1) { + bldr.setCharAt(bldr.length() - 1, '}'); + } + return bldr.toString(); + } + + @Override + public void setConf(Configuration configuration) { + this.conf = configuration; + } + + @Override + public Configuration getConf() { + return conf; + } + + private final class Updater implements Runnable { + private final ArrayBlockingQueue<Pair<Integer, Vector>> queue = + new ArrayBlockingQueue<>(100); + private boolean shutdown = false; + private boolean shutdownComplete = false; + + public void shutdown() { + try { + synchronized (this) { + while (!shutdownComplete) { + shutdown = true; + wait(10000L); // Arbitrarily, wait 10 seconds rather than forever for this + } + } + } catch (InterruptedException e) { + log.warn("Interrupted waiting to shutdown() : ", e); + } + } + + public boolean update(int topic, Vector v) { + if (shutdown) { // maybe don't do this? + throw new IllegalStateException("In SHUTDOWN state: cannot submit tasks"); + } + while (true) { // keep trying if interrupted + try { + // start async operation by submitting to the queue + queue.put(Pair.of(topic, v)); + // return once you got access to the queue + return true; + } catch (InterruptedException e) { + log.warn("Interrupted trying to queue update:", e); + } + } + } + + @Override + public void run() { + while (!shutdown) { + try { + Pair<Integer, Vector> pair = queue.poll(1, TimeUnit.SECONDS); + if (pair != null) { + updateTopic(pair.getFirst(), pair.getSecond()); + } + } catch (InterruptedException e) { + log.warn("Interrupted waiting to poll for update", e); + } + } + // in shutdown mode, finish remaining tasks! + for (Pair<Integer, Vector> pair : queue) { + updateTopic(pair.getFirst(), pair.getSecond()); + } + synchronized (this) { + shutdownComplete = true; + notifyAll(); + } + } + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/package-info.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/package-info.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/package-info.java new file mode 100644 index 0000000..9926b91 --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/package-info.java @@ -0,0 +1,13 @@ +/** + * <p>This package provides several clustering algorithm implementations. Clustering usually groups a set of + * objects into groups of similar items. The definition of similarity is usually up to you - for text documents, + * cosine distance/similarity is recommended. Mahout also features other distance measures, such as + * Euclidean distance.</p> + * + * <p>The input to each clustering algorithm is a set of vectors representing your items.
For texts in general these are + * <a href="http://en.wikipedia.org/wiki/TFIDF">TFIDF</a> or + * <a href="http://en.wikipedia.org/wiki/Bag_of_words">Bag of words</a> representations of the documents.</p> + * + * <p>Output of each clustering algorithm is either a hard or soft assignment of items to clusters.</p> + */ +package org.apache.mahout.clustering; http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/spectral/AffinityMatrixInputJob.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/spectral/AffinityMatrixInputJob.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/spectral/AffinityMatrixInputJob.java new file mode 100644 index 0000000..aa12b9e --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/spectral/AffinityMatrixInputJob.java @@ -0,0 +1,84 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.clustering.spectral; + +import java.io.IOException; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.mapreduce.Job; +import org.apache.hadoop.mapreduce.lib.input.FileInputFormat; +import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat; +import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat; +import org.apache.mahout.common.HadoopUtil; +import org.apache.mahout.math.VectorWritable; +import org.apache.mahout.math.hadoop.DistributedRowMatrix; + +public final class AffinityMatrixInputJob { + + private AffinityMatrixInputJob() { + } + + /** + * Initializes and executes the job of reading the documents containing + * the data of the affinity matrix in (x_i, x_j, value) format. 
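 + * <p>For illustration only (the indices and weights below are made up), each line of
 + * the input text files is expected to look like:</p>
 + * <pre>{@code
 + * 0,1,0.85
 + * 0,2,0.10
 + * 1,2,0.42
 + * }</pre>
 + * <p>where the first two fields are the indices of two data points and the third is the
 + * affinity between them.</p>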
+ */ + public static void runJob(Path input, Path output, int rows, int cols) + throws IOException, InterruptedException, ClassNotFoundException { + Configuration conf = new Configuration(); + HadoopUtil.delete(conf, output); + + conf.setInt(Keys.AFFINITY_DIMENSIONS, rows); + Job job = new Job(conf, "AffinityMatrixInputJob: " + input + " -> M/R -> " + output); + + job.setMapOutputKeyClass(IntWritable.class); + job.setMapOutputValueClass(DistributedRowMatrix.MatrixEntryWritable.class); + job.setOutputKeyClass(IntWritable.class); + job.setOutputValueClass(VectorWritable.class); + job.setOutputFormatClass(SequenceFileOutputFormat.class); + job.setMapperClass(AffinityMatrixInputMapper.class); + job.setReducerClass(AffinityMatrixInputReducer.class); + + FileInputFormat.addInputPath(job, input); + FileOutputFormat.setOutputPath(job, output); + + job.setJarByClass(AffinityMatrixInputJob.class); + + boolean succeeded = job.waitForCompletion(true); + if (!succeeded) { + throw new IllegalStateException("Job failed!"); + } + } + + /** + * A transparent wrapper for the above method which handles the tedious tasks + * of setting and retrieving system Paths. Hands back a fully-populated + * and initialized DistributedRowMatrix. + */ + public static DistributedRowMatrix runJob(Path input, Path output, int dimensions) + throws IOException, InterruptedException, ClassNotFoundException { + Path seqFiles = new Path(output, "seqfiles-" + (System.nanoTime() & 0xFF)); + runJob(input, seqFiles, dimensions, dimensions); + DistributedRowMatrix a = new DistributedRowMatrix(seqFiles, + new Path(seqFiles, "seqtmp-" + (System.nanoTime() & 0xFF)), + dimensions, dimensions); + a.setConf(new Configuration()); + return a; + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/spectral/AffinityMatrixInputMapper.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/spectral/AffinityMatrixInputMapper.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/spectral/AffinityMatrixInputMapper.java new file mode 100644 index 0000000..30d2404 --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/spectral/AffinityMatrixInputMapper.java @@ -0,0 +1,78 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + + package org.apache.mahout.clustering.spectral; + + import java.io.IOException; + import java.util.regex.Pattern; + + import org.apache.hadoop.io.IntWritable; + import org.apache.hadoop.io.LongWritable; + import org.apache.hadoop.io.Text; + import org.apache.hadoop.mapreduce.Mapper; + import org.apache.mahout.math.hadoop.DistributedRowMatrix; + import org.slf4j.Logger; + import org.slf4j.LoggerFactory; + + /** + * <p>Handles reading the files representing the affinity matrix. Since the affinity + * matrix is representative of a graph, each line in all the files should + * take the form:</p> + * + * {@code i,j,value} + * + * <p>where {@code i} and {@code j} are the {@code i}th and + * {@code j}th data points in the entire set, and {@code value} + * represents some measure of the affinity between them. This + * is, simply, a method for representing a graph textually.</p> + */ + public class AffinityMatrixInputMapper + extends Mapper<LongWritable, Text, IntWritable, DistributedRowMatrix.MatrixEntryWritable> { + + private static final Logger log = LoggerFactory.getLogger(AffinityMatrixInputMapper.class); + + private static final Pattern COMMA_PATTERN = Pattern.compile(","); + + @Override + protected void map(LongWritable key, Text value, Context context) throws IOException, InterruptedException { + + String[] elements = COMMA_PATTERN.split(value.toString()); + log.debug("(DEBUG - MAP) Key[{}], Value[{}]", key.get(), value); + + // enforce well-formed textual representation of the graph + if (elements.length != 3) { + throw new IOException("Expected input of length 3, received " + + elements.length + ". Please make sure you adhere to " + + "the structure of (i,j,value) for representing a graph in text. " + + "Input line was: '" + value + "'."); + } + if (elements[0].isEmpty() || elements[1].isEmpty() || elements[2].isEmpty()) { + throw new IOException("Found an element of 0 length. Please be sure you adhere to the structure of " + + "(i,j,value) for representing a graph in text."); + } + + // parse the line of text into a DistributedRowMatrix entry, + // making the row (elements[0]) the key to the Reducer, and + // setting the column (elements[1]) in the entry itself + DistributedRowMatrix.MatrixEntryWritable toAdd = new DistributedRowMatrix.MatrixEntryWritable(); + IntWritable row = new IntWritable(Integer.valueOf(elements[0])); + toAdd.setRow(-1); // already set as the Reducer's key + toAdd.setCol(Integer.valueOf(elements[1])); + toAdd.setVal(Double.valueOf(elements[2])); + context.write(row, toAdd); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/spectral/AffinityMatrixInputReducer.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/spectral/AffinityMatrixInputReducer.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/spectral/AffinityMatrixInputReducer.java new file mode 100644 index 0000000..d892969 --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/spectral/AffinityMatrixInputReducer.java @@ -0,0 +1,59 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.clustering.spectral; + +import java.io.IOException; + +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.mapreduce.Reducer; +import org.apache.mahout.math.RandomAccessSparseVector; +import org.apache.mahout.math.SequentialAccessSparseVector; +import org.apache.mahout.math.VectorWritable; +import org.apache.mahout.math.hadoop.DistributedRowMatrix; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Tasked with taking each DistributedRowMatrix entry and collecting them + * into vectors corresponding to rows. The input and output keys are the same, + * corresponding to the row in the ensuing matrix. The matrix entries are + * entered into a vector according to the column to which they belong, and + * the vector is then given the key corresponding to its row. + */ +public class AffinityMatrixInputReducer + extends Reducer<IntWritable, DistributedRowMatrix.MatrixEntryWritable, IntWritable, VectorWritable> { + + private static final Logger log = LoggerFactory.getLogger(AffinityMatrixInputReducer.class); + + @Override + protected void reduce(IntWritable row, Iterable<DistributedRowMatrix.MatrixEntryWritable> values, Context context) + throws IOException, InterruptedException { + int size = context.getConfiguration().getInt(Keys.AFFINITY_DIMENSIONS, Integer.MAX_VALUE); + RandomAccessSparseVector out = new RandomAccessSparseVector(size, 100); + + for (DistributedRowMatrix.MatrixEntryWritable element : values) { + out.setQuick(element.getCol(), element.getVal()); + if (log.isDebugEnabled()) { + log.debug("(DEBUG - REDUCE) Row[{}], Column[{}], Value[{}]", + row.get(), element.getCol(), element.getVal()); + } + } + SequentialAccessSparseVector output = new SequentialAccessSparseVector(out); + context.write(row, new VectorWritable(output)); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/spectral/IntDoublePairWritable.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/spectral/IntDoublePairWritable.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/spectral/IntDoublePairWritable.java new file mode 100644 index 0000000..593cc58 --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/spectral/IntDoublePairWritable.java @@ -0,0 +1,75 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.clustering.spectral; + +import java.io.DataInput; +import java.io.DataOutput; +import java.io.IOException; + +import org.apache.hadoop.io.Writable; + +/** + * This class is a Writable implementation of the mahout.common.Pair + * generic class. Since the generic types would also themselves have to + * implement Writable, it made more sense to create a more specialized + * version of the class altogether. + * + * In essence, this can be treated as a single Vector Element. + */ +public class IntDoublePairWritable implements Writable { + + private int key; + private double value; + + public IntDoublePairWritable() { + } + + public IntDoublePairWritable(int k, double v) { + this.key = k; + this.value = v; + } + + public void setKey(int k) { + this.key = k; + } + + public void setValue(double v) { + this.value = v; + } + + @Override + public void readFields(DataInput in) throws IOException { + this.key = in.readInt(); + this.value = in.readDouble(); + } + + @Override + public void write(DataOutput out) throws IOException { + out.writeInt(key); + out.writeDouble(value); + } + + public int getKey() { + return key; + } + + public double getValue() { + return value; + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/spectral/Keys.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/spectral/Keys.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/spectral/Keys.java new file mode 100644 index 0000000..268a365 --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/spectral/Keys.java @@ -0,0 +1,31 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.clustering.spectral; + +public class Keys { + + /** + * Sets the SequenceFile index for the diagonal matrix. 
+ */ + public static final int DIAGONAL_CACHE_INDEX = 1; + + public static final String AFFINITY_DIMENSIONS = "org.apache.mahout.clustering.spectral.common.affinitydimensions"; + + private Keys() {} + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/spectral/MatrixDiagonalizeJob.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/spectral/MatrixDiagonalizeJob.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/spectral/MatrixDiagonalizeJob.java new file mode 100644 index 0000000..f245f99 --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/spectral/MatrixDiagonalizeJob.java @@ -0,0 +1,108 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.clustering.spectral; + +import java.io.IOException; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.NullWritable; +import org.apache.hadoop.mapreduce.Job; +import org.apache.hadoop.mapreduce.Mapper; +import org.apache.hadoop.mapreduce.Reducer; +import org.apache.hadoop.mapreduce.lib.input.FileInputFormat; +import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat; +import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat; +import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat; +import org.apache.mahout.common.HadoopUtil; +import org.apache.mahout.math.DenseVector; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.VectorWritable; + +/** + * Given a matrix, this job returns a vector whose i_th element is the + * sum of all the elements in the i_th row of the original matrix. 
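 + * <p>A small worked example: for the 3x3 affinity matrix</p>
 + * <pre>{@code
 + * 0.0  0.8  0.2
 + * 0.8  0.0  0.5
 + * 0.2  0.5  0.0
 + * }</pre>
 + * <p>the resulting vector is (1.0, 1.3, 0.7), i.e. the per-row sums.</p>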
+ */ +public final class MatrixDiagonalizeJob { + + private MatrixDiagonalizeJob() { + } + + public static Vector runJob(Path affInput, int dimensions) + throws IOException, ClassNotFoundException, InterruptedException { + + // set up all the job tasks + Configuration conf = new Configuration(); + Path diagOutput = new Path(affInput.getParent(), "diagonal"); + HadoopUtil.delete(conf, diagOutput); + conf.setInt(Keys.AFFINITY_DIMENSIONS, dimensions); + Job job = new Job(conf, "MatrixDiagonalizeJob"); + + job.setInputFormatClass(SequenceFileInputFormat.class); + job.setMapOutputKeyClass(NullWritable.class); + job.setMapOutputValueClass(IntDoublePairWritable.class); + job.setOutputKeyClass(NullWritable.class); + job.setOutputValueClass(VectorWritable.class); + job.setOutputFormatClass(SequenceFileOutputFormat.class); + job.setMapperClass(MatrixDiagonalizeMapper.class); + job.setReducerClass(MatrixDiagonalizeReducer.class); + + FileInputFormat.addInputPath(job, affInput); + FileOutputFormat.setOutputPath(job, diagOutput); + + job.setJarByClass(MatrixDiagonalizeJob.class); + + boolean succeeded = job.waitForCompletion(true); + if (!succeeded) { + throw new IllegalStateException("Job failed!"); + } + + // read the results back from the path + return VectorCache.load(conf, new Path(diagOutput, "part-r-00000")); + } + + public static class MatrixDiagonalizeMapper + extends Mapper<IntWritable, VectorWritable, NullWritable, IntDoublePairWritable> { + + @Override + protected void map(IntWritable key, VectorWritable row, Context context) + throws IOException, InterruptedException { + // store the sum + IntDoublePairWritable store = new IntDoublePairWritable(key.get(), row.get().zSum()); + context.write(NullWritable.get(), store); + } + } + + public static class MatrixDiagonalizeReducer + extends Reducer<NullWritable, IntDoublePairWritable, NullWritable, VectorWritable> { + + @Override + protected void reduce(NullWritable key, Iterable<IntDoublePairWritable> values, + Context context) throws IOException, InterruptedException { + // create the return vector + Vector retval = new DenseVector(context.getConfiguration().getInt(Keys.AFFINITY_DIMENSIONS, Integer.MAX_VALUE)); + // put everything in its correct spot + for (IntDoublePairWritable e : values) { + retval.setQuick(e.getKey(), e.getValue()); + } + // write it out + context.write(key, new VectorWritable(retval)); + } + } +}
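The two spectral preparation jobs above are typically used together: the textual (i,j,value) affinity data is first converted into a row-wise DistributedRowMatrix, and the rows of that matrix are then summed into a single vector. The sketch below shows one way the two public runJob() entry points could be wired up; the driver class, paths, and dimension count are hypothetical and not part of this commit.

import org.apache.hadoop.fs.Path;
import org.apache.mahout.clustering.spectral.AffinityMatrixInputJob;
import org.apache.mahout.clustering.spectral.MatrixDiagonalizeJob;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.hadoop.DistributedRowMatrix;

/** Hypothetical local driver; illustrates the call sequence only. */
public final class SpectralPrepDriver {

  private SpectralPrepDriver() {
  }

  public static void main(String[] args) throws Exception {
    int dimensions = 1000;                              // assumed number of data points
    Path rawAffinity = new Path("/tmp/affinity-text");  // text files of i,j,value triples (hypothetical path)
    Path matrixOut = new Path("/tmp/affinity-matrix");  // hypothetical output path

    // Parse the textual triples into a row-wise DistributedRowMatrix
    // (AffinityMatrixInputMapper / AffinityMatrixInputReducer run underneath).
    DistributedRowMatrix affinity =
        AffinityMatrixInputJob.runJob(rawAffinity, matrixOut, dimensions);

    // Sum each row of the affinity matrix into one dense vector;
    // element i of the result holds the sum of row i.
    Vector rowSums = MatrixDiagonalizeJob.runJob(affinity.getRowPath(), dimensions);

    System.out.println("Row sums: " + rowSums);
  }
}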
