http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/CVB0DocInferenceMapper.java ---------------------------------------------------------------------- diff --git a/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/CVB0DocInferenceMapper.java b/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/CVB0DocInferenceMapper.java new file mode 100644 index 0000000..46fcc7f --- /dev/null +++ b/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/CVB0DocInferenceMapper.java @@ -0,0 +1,51 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.mahout.clustering.lda.cvb; + +import org.apache.hadoop.io.IntWritable; +import org.apache.mahout.math.DenseVector; +import org.apache.mahout.math.Matrix; +import org.apache.mahout.math.SparseRowMatrix; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.VectorWritable; + +import java.io.IOException; + +public class CVB0DocInferenceMapper extends CachingCVB0Mapper { + + private final VectorWritable topics = new VectorWritable(); + + @Override + public void map(IntWritable docId, VectorWritable doc, Context context) + throws IOException, InterruptedException { + int numTopics = getNumTopics(); + Vector docTopics = new DenseVector(numTopics).assign(1.0 / numTopics); + Matrix docModel = new SparseRowMatrix(numTopics, doc.get().size()); + int maxIters = getMaxIters(); + ModelTrainer modelTrainer = getModelTrainer(); + for (int i = 0; i < maxIters; i++) { + modelTrainer.getReadModel().trainDocTopicModel(doc.get(), docTopics, docModel); + } + topics.set(docTopics); + context.write(docId, topics); + } + + @Override + protected void cleanup(Context context) { + getModelTrainer().stop(); + } +}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/CVB0Driver.java ---------------------------------------------------------------------- diff --git a/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/CVB0Driver.java b/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/CVB0Driver.java new file mode 100644 index 0000000..3eee446 --- /dev/null +++ b/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/CVB0Driver.java @@ -0,0 +1,536 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.mahout.clustering.lda.cvb; + +import com.google.common.base.Joiner; +import com.google.common.base.Preconditions; +import com.google.common.collect.Lists; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.filecache.DistributedCache; +import org.apache.hadoop.fs.FileStatus; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.DoubleWritable; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.SequenceFile; +import org.apache.hadoop.io.Text; +import org.apache.hadoop.mapreduce.Job; +import org.apache.hadoop.mapreduce.Reducer; +import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat; +import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat; +import org.apache.hadoop.util.ToolRunner; +import org.apache.mahout.common.AbstractJob; +import org.apache.mahout.common.HadoopUtil; +import org.apache.mahout.common.Pair; +import org.apache.mahout.common.commandline.DefaultOptionCreator; +import org.apache.mahout.common.iterator.sequencefile.PathFilters; +import org.apache.mahout.common.iterator.sequencefile.PathType; +import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable; +import org.apache.mahout.common.mapreduce.VectorSumReducer; +import org.apache.mahout.math.VectorWritable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.net.URI; +import java.util.List; + +/** + * See {@link CachingCVB0Mapper} for more details on scalability and room for improvement. + * To try out this LDA implementation without using Hadoop, check out + * {@link InMemoryCollapsedVariationalBayes0}. If you want to do training directly in java code + * with your own main(), then look to {@link ModelTrainer} and {@link TopicModel}. + * + * Usage: {@code ./bin/mahout cvb <i>options</i>} + * <p> + * Valid options include: + * <dl> + * <dt>{@code --input path}</td> + * <dd>Input path for {@code SequenceFile<IntWritable, VectorWritable>} document vectors. See + * {@link org.apache.mahout.vectorizer.SparseVectorsFromSequenceFiles} + * for details on how to generate this input format.</dd> + * <dt>{@code --dictionary path}</dt> + * <dd>Path to dictionary file(s) generated during construction of input document vectors (glob + * expression supported). If set, this data is scanned to determine an appropriate value for option + * {@code --num_terms}.</dd> + * <dt>{@code --output path}</dt> + * <dd>Output path for topic-term distributions.</dd> + * <dt>{@code --doc_topic_output path}</dt> + * <dd>Output path for doc-topic distributions.</dd> + * <dt>{@code --num_topics k}</dt> + * <dd>Number of latent topics.</dd> + * <dt>{@code --num_terms nt}</dt> + * <dd>Number of unique features defined by input document vectors. If option {@code --dictionary} + * is defined and this option is unspecified, term count is calculated from dictionary.</dd> + * <dt>{@code --topic_model_temp_dir path}</dt> + * <dd>Path in which to store model state after each iteration.</dd> + * <dt>{@code --maxIter i}</dt> + * <dd>Maximum number of iterations to perform. If this value is less than or equal to the number of + * iteration states found beneath the path specified by option {@code --topic_model_temp_dir}, no + * further iterations are performed. Instead, output topic-term and doc-topic distributions are + * generated using data from the specified iteration.</dd> + * <dt>{@code --max_doc_topic_iters i}</dt> + * <dd>Maximum number of iterations per doc for p(topic|doc) learning. Defaults to {@code 10}.</dd> + * <dt>{@code --doc_topic_smoothing a}</dt> + * <dd>Smoothing for doc-topic distribution. Defaults to {@code 0.0001}.</dd> + * <dt>{@code --term_topic_smoothing e}</dt> + * <dd>Smoothing for topic-term distribution. Defaults to {@code 0.0001}.</dd> + * <dt>{@code --random_seed seed}</dt> + * <dd>Integer seed for random number generation.</dd> + * <dt>{@code --test_set_percentage p}</dt> + * <dd>Fraction of data to hold out for testing. Defaults to {@code 0.0}.</dd> + * <dt>{@code --iteration_block_size block}</dt> + * <dd>Number of iterations between perplexity checks. Defaults to {@code 10}. This option is + * ignored unless option {@code --test_set_percentage} is greater than zero.</dd> + * </dl> + */ +public class CVB0Driver extends AbstractJob { + private static final Logger log = LoggerFactory.getLogger(CVB0Driver.class); + + public static final String NUM_TOPICS = "num_topics"; + public static final String NUM_TERMS = "num_terms"; + public static final String DOC_TOPIC_SMOOTHING = "doc_topic_smoothing"; + public static final String TERM_TOPIC_SMOOTHING = "term_topic_smoothing"; + public static final String DICTIONARY = "dictionary"; + public static final String DOC_TOPIC_OUTPUT = "doc_topic_output"; + public static final String MODEL_TEMP_DIR = "topic_model_temp_dir"; + public static final String ITERATION_BLOCK_SIZE = "iteration_block_size"; + public static final String RANDOM_SEED = "random_seed"; + public static final String TEST_SET_FRACTION = "test_set_fraction"; + public static final String NUM_TRAIN_THREADS = "num_train_threads"; + public static final String NUM_UPDATE_THREADS = "num_update_threads"; + public static final String MAX_ITERATIONS_PER_DOC = "max_doc_topic_iters"; + public static final String MODEL_WEIGHT = "prev_iter_mult"; + public static final String NUM_REDUCE_TASKS = "num_reduce_tasks"; + public static final String BACKFILL_PERPLEXITY = "backfill_perplexity"; + private static final String MODEL_PATHS = "mahout.lda.cvb.modelPath"; + + private static final double DEFAULT_CONVERGENCE_DELTA = 0; + private static final double DEFAULT_DOC_TOPIC_SMOOTHING = 0.0001; + private static final double DEFAULT_TERM_TOPIC_SMOOTHING = 0.0001; + private static final int DEFAULT_ITERATION_BLOCK_SIZE = 10; + private static final double DEFAULT_TEST_SET_FRACTION = 0; + private static final int DEFAULT_NUM_TRAIN_THREADS = 4; + private static final int DEFAULT_NUM_UPDATE_THREADS = 1; + private static final int DEFAULT_MAX_ITERATIONS_PER_DOC = 10; + private static final int DEFAULT_NUM_REDUCE_TASKS = 10; + + @Override + public int run(String[] args) throws Exception { + addInputOption(); + addOutputOption(); + addOption(DefaultOptionCreator.maxIterationsOption().create()); + addOption(DefaultOptionCreator.CONVERGENCE_DELTA_OPTION, "cd", "The convergence delta value", + String.valueOf(DEFAULT_CONVERGENCE_DELTA)); + addOption(DefaultOptionCreator.overwriteOption().create()); + + addOption(NUM_TOPICS, "k", "Number of topics to learn", true); + addOption(NUM_TERMS, "nt", "Vocabulary size", false); + addOption(DOC_TOPIC_SMOOTHING, "a", "Smoothing for document/topic distribution", + String.valueOf(DEFAULT_DOC_TOPIC_SMOOTHING)); + addOption(TERM_TOPIC_SMOOTHING, "e", "Smoothing for topic/term distribution", + String.valueOf(DEFAULT_TERM_TOPIC_SMOOTHING)); + addOption(DICTIONARY, "dict", "Path to term-dictionary file(s) (glob expression supported)", false); + addOption(DOC_TOPIC_OUTPUT, "dt", "Output path for the training doc/topic distribution", false); + addOption(MODEL_TEMP_DIR, "mt", "Path to intermediate model path (useful for restarting)", false); + addOption(ITERATION_BLOCK_SIZE, "block", "Number of iterations per perplexity check", + String.valueOf(DEFAULT_ITERATION_BLOCK_SIZE)); + addOption(RANDOM_SEED, "seed", "Random seed", false); + addOption(TEST_SET_FRACTION, "tf", "Fraction of data to hold out for testing", + String.valueOf(DEFAULT_TEST_SET_FRACTION)); + addOption(NUM_TRAIN_THREADS, "ntt", "number of threads per mapper to train with", + String.valueOf(DEFAULT_NUM_TRAIN_THREADS)); + addOption(NUM_UPDATE_THREADS, "nut", "number of threads per mapper to update the model with", + String.valueOf(DEFAULT_NUM_UPDATE_THREADS)); + addOption(MAX_ITERATIONS_PER_DOC, "mipd", "max number of iterations per doc for p(topic|doc) learning", + String.valueOf(DEFAULT_MAX_ITERATIONS_PER_DOC)); + addOption(NUM_REDUCE_TASKS, null, "number of reducers to use during model estimation", + String.valueOf(DEFAULT_NUM_REDUCE_TASKS)); + addOption(buildOption(BACKFILL_PERPLEXITY, null, "enable backfilling of missing perplexity values", false, false, + null)); + + if (parseArguments(args) == null) { + return -1; + } + + int numTopics = Integer.parseInt(getOption(NUM_TOPICS)); + Path inputPath = getInputPath(); + Path topicModelOutputPath = getOutputPath(); + int maxIterations = Integer.parseInt(getOption(DefaultOptionCreator.MAX_ITERATIONS_OPTION)); + int iterationBlockSize = Integer.parseInt(getOption(ITERATION_BLOCK_SIZE)); + double convergenceDelta = Double.parseDouble(getOption(DefaultOptionCreator.CONVERGENCE_DELTA_OPTION)); + double alpha = Double.parseDouble(getOption(DOC_TOPIC_SMOOTHING)); + double eta = Double.parseDouble(getOption(TERM_TOPIC_SMOOTHING)); + int numTrainThreads = Integer.parseInt(getOption(NUM_TRAIN_THREADS)); + int numUpdateThreads = Integer.parseInt(getOption(NUM_UPDATE_THREADS)); + int maxItersPerDoc = Integer.parseInt(getOption(MAX_ITERATIONS_PER_DOC)); + Path dictionaryPath = hasOption(DICTIONARY) ? new Path(getOption(DICTIONARY)) : null; + int numTerms = hasOption(NUM_TERMS) + ? Integer.parseInt(getOption(NUM_TERMS)) + : getNumTerms(getConf(), dictionaryPath); + Path docTopicOutputPath = hasOption(DOC_TOPIC_OUTPUT) ? new Path(getOption(DOC_TOPIC_OUTPUT)) : null; + Path modelTempPath = hasOption(MODEL_TEMP_DIR) + ? new Path(getOption(MODEL_TEMP_DIR)) + : getTempPath("topicModelState"); + long seed = hasOption(RANDOM_SEED) + ? Long.parseLong(getOption(RANDOM_SEED)) + : System.nanoTime() % 10000; + float testFraction = hasOption(TEST_SET_FRACTION) + ? Float.parseFloat(getOption(TEST_SET_FRACTION)) + : 0.0f; + int numReduceTasks = Integer.parseInt(getOption(NUM_REDUCE_TASKS)); + boolean backfillPerplexity = hasOption(BACKFILL_PERPLEXITY); + + return run(getConf(), inputPath, topicModelOutputPath, numTopics, numTerms, alpha, eta, + maxIterations, iterationBlockSize, convergenceDelta, dictionaryPath, docTopicOutputPath, + modelTempPath, seed, testFraction, numTrainThreads, numUpdateThreads, maxItersPerDoc, + numReduceTasks, backfillPerplexity); + } + + private static int getNumTerms(Configuration conf, Path dictionaryPath) throws IOException { + FileSystem fs = dictionaryPath.getFileSystem(conf); + Text key = new Text(); + IntWritable value = new IntWritable(); + int maxTermId = -1; + for (FileStatus stat : fs.globStatus(dictionaryPath)) { + SequenceFile.Reader reader = new SequenceFile.Reader(fs, stat.getPath(), conf); + while (reader.next(key, value)) { + maxTermId = Math.max(maxTermId, value.get()); + } + } + return maxTermId + 1; + } + + public int run(Configuration conf, + Path inputPath, + Path topicModelOutputPath, + int numTopics, + int numTerms, + double alpha, + double eta, + int maxIterations, + int iterationBlockSize, + double convergenceDelta, + Path dictionaryPath, + Path docTopicOutputPath, + Path topicModelStateTempPath, + long randomSeed, + float testFraction, + int numTrainThreads, + int numUpdateThreads, + int maxItersPerDoc, + int numReduceTasks, + boolean backfillPerplexity) + throws ClassNotFoundException, IOException, InterruptedException { + + setConf(conf); + + // verify arguments + Preconditions.checkArgument(testFraction >= 0.0 && testFraction <= 1.0, + "Expected 'testFraction' value in range [0, 1] but found value '%s'", testFraction); + Preconditions.checkArgument(!backfillPerplexity || testFraction > 0.0, + "Expected 'testFraction' value in range (0, 1] but found value '%s'", testFraction); + + String infoString = "Will run Collapsed Variational Bayes (0th-derivative approximation) " + + "learning for LDA on {} (numTerms: {}), finding {}-topics, with document/topic prior {}, " + + "topic/term prior {}. Maximum iterations to run will be {}, unless the change in " + + "perplexity is less than {}. Topic model output (p(term|topic) for each topic) will be " + + "stored {}. Random initialization seed is {}, holding out {} of the data for perplexity " + + "check\n"; + log.info(infoString, inputPath, numTerms, numTopics, alpha, eta, maxIterations, + convergenceDelta, topicModelOutputPath, randomSeed, testFraction); + infoString = dictionaryPath == null + ? "" : "Dictionary to be used located " + dictionaryPath.toString() + '\n'; + infoString += docTopicOutputPath == null + ? "" : "p(topic|docId) will be stored " + docTopicOutputPath.toString() + '\n'; + log.info(infoString); + + FileSystem fs = FileSystem.get(topicModelStateTempPath.toUri(), conf); + int iterationNumber = getCurrentIterationNumber(conf, topicModelStateTempPath, maxIterations); + log.info("Current iteration number: {}", iterationNumber); + + conf.set(NUM_TOPICS, String.valueOf(numTopics)); + conf.set(NUM_TERMS, String.valueOf(numTerms)); + conf.set(DOC_TOPIC_SMOOTHING, String.valueOf(alpha)); + conf.set(TERM_TOPIC_SMOOTHING, String.valueOf(eta)); + conf.set(RANDOM_SEED, String.valueOf(randomSeed)); + conf.set(NUM_TRAIN_THREADS, String.valueOf(numTrainThreads)); + conf.set(NUM_UPDATE_THREADS, String.valueOf(numUpdateThreads)); + conf.set(MAX_ITERATIONS_PER_DOC, String.valueOf(maxItersPerDoc)); + conf.set(MODEL_WEIGHT, "1"); // TODO + conf.set(TEST_SET_FRACTION, String.valueOf(testFraction)); + + List<Double> perplexities = Lists.newArrayList(); + for (int i = 1; i <= iterationNumber; i++) { + // form path to model + Path modelPath = modelPath(topicModelStateTempPath, i); + + // read perplexity + double perplexity = readPerplexity(conf, topicModelStateTempPath, i); + if (Double.isNaN(perplexity)) { + if (!(backfillPerplexity && i % iterationBlockSize == 0)) { + continue; + } + log.info("Backfilling perplexity at iteration {}", i); + if (!fs.exists(modelPath)) { + log.error("Model path '{}' does not exist; Skipping iteration {} perplexity calculation", + modelPath.toString(), i); + continue; + } + perplexity = calculatePerplexity(conf, inputPath, modelPath, i); + } + + // register and log perplexity + perplexities.add(perplexity); + log.info("Perplexity at iteration {} = {}", i, perplexity); + } + + long startTime = System.currentTimeMillis(); + while (iterationNumber < maxIterations) { + // test convergence + if (convergenceDelta > 0.0) { + double delta = rateOfChange(perplexities); + if (delta < convergenceDelta) { + log.info("Convergence achieved at iteration {} with perplexity {} and delta {}", + iterationNumber, perplexities.get(perplexities.size() - 1), delta); + break; + } + } + + // update model + iterationNumber++; + log.info("About to run iteration {} of {}", iterationNumber, maxIterations); + Path modelInputPath = modelPath(topicModelStateTempPath, iterationNumber - 1); + Path modelOutputPath = modelPath(topicModelStateTempPath, iterationNumber); + runIteration(conf, inputPath, modelInputPath, modelOutputPath, iterationNumber, + maxIterations, numReduceTasks); + + // calculate perplexity + if (testFraction > 0 && iterationNumber % iterationBlockSize == 0) { + perplexities.add(calculatePerplexity(conf, inputPath, modelOutputPath, iterationNumber)); + log.info("Current perplexity = {}", perplexities.get(perplexities.size() - 1)); + log.info("(p_{} - p_{}) / p_0 = {}; target = {}", iterationNumber, iterationNumber - iterationBlockSize, + rateOfChange(perplexities), convergenceDelta); + } + } + log.info("Completed {} iterations in {} seconds", iterationNumber, + (System.currentTimeMillis() - startTime) / 1000); + log.info("Perplexities: ({})", Joiner.on(", ").join(perplexities)); + + // write final topic-term and doc-topic distributions + Path finalIterationData = modelPath(topicModelStateTempPath, iterationNumber); + Job topicModelOutputJob = topicModelOutputPath != null + ? writeTopicModel(conf, finalIterationData, topicModelOutputPath) + : null; + Job docInferenceJob = docTopicOutputPath != null + ? writeDocTopicInference(conf, inputPath, finalIterationData, docTopicOutputPath) + : null; + if (topicModelOutputJob != null && !topicModelOutputJob.waitForCompletion(true)) { + return -1; + } + if (docInferenceJob != null && !docInferenceJob.waitForCompletion(true)) { + return -1; + } + return 0; + } + + private static double rateOfChange(List<Double> perplexities) { + int sz = perplexities.size(); + if (sz < 2) { + return Double.MAX_VALUE; + } + return Math.abs(perplexities.get(sz - 1) - perplexities.get(sz - 2)) / perplexities.get(0); + } + + private double calculatePerplexity(Configuration conf, Path corpusPath, Path modelPath, int iteration) + throws IOException, ClassNotFoundException, InterruptedException { + String jobName = "Calculating perplexity for " + modelPath; + log.info("About to run: {}", jobName); + + Path outputPath = perplexityPath(modelPath.getParent(), iteration); + Job job = prepareJob(corpusPath, outputPath, CachingCVB0PerplexityMapper.class, DoubleWritable.class, + DoubleWritable.class, DualDoubleSumReducer.class, DoubleWritable.class, DoubleWritable.class); + + job.setJobName(jobName); + job.setCombinerClass(DualDoubleSumReducer.class); + job.setNumReduceTasks(1); + setModelPaths(job, modelPath); + HadoopUtil.delete(conf, outputPath); + if (!job.waitForCompletion(true)) { + throw new InterruptedException("Failed to calculate perplexity for: " + modelPath); + } + return readPerplexity(conf, modelPath.getParent(), iteration); + } + + /** + * Sums keys and values independently. + */ + public static class DualDoubleSumReducer extends + Reducer<DoubleWritable, DoubleWritable, DoubleWritable, DoubleWritable> { + private final DoubleWritable outKey = new DoubleWritable(); + private final DoubleWritable outValue = new DoubleWritable(); + + @Override + public void run(Context context) throws IOException, + InterruptedException { + double keySum = 0.0; + double valueSum = 0.0; + while (context.nextKey()) { + keySum += context.getCurrentKey().get(); + for (DoubleWritable value : context.getValues()) { + valueSum += value.get(); + } + } + outKey.set(keySum); + outValue.set(valueSum); + context.write(outKey, outValue); + } + } + + /** + * @param topicModelStateTemp + * @param iteration + * @return {@code double[2]} where first value is perplexity and second is model weight of those + * documents sampled during perplexity computation, or {@code null} if no perplexity data + * exists for the given iteration. + * @throws IOException + */ + public static double readPerplexity(Configuration conf, Path topicModelStateTemp, int iteration) + throws IOException { + Path perplexityPath = perplexityPath(topicModelStateTemp, iteration); + FileSystem fs = FileSystem.get(perplexityPath.toUri(), conf); + if (!fs.exists(perplexityPath)) { + log.warn("Perplexity path {} does not exist, returning NaN", perplexityPath); + return Double.NaN; + } + double perplexity = 0; + double modelWeight = 0; + long n = 0; + for (Pair<DoubleWritable, DoubleWritable> pair : new SequenceFileDirIterable<DoubleWritable, DoubleWritable>( + perplexityPath, PathType.LIST, PathFilters.partFilter(), null, true, conf)) { + modelWeight += pair.getFirst().get(); + perplexity += pair.getSecond().get(); + n++; + } + log.info("Read {} entries with total perplexity {} and model weight {}", n, + perplexity, modelWeight); + return perplexity / modelWeight; + } + + private Job writeTopicModel(Configuration conf, Path modelInput, Path output) + throws IOException, InterruptedException, ClassNotFoundException { + String jobName = String.format("Writing final topic/term distributions from %s to %s", modelInput, output); + log.info("About to run: {}", jobName); + + Job job = prepareJob(modelInput, output, SequenceFileInputFormat.class, CVB0TopicTermVectorNormalizerMapper.class, + IntWritable.class, VectorWritable.class, SequenceFileOutputFormat.class, jobName); + job.submit(); + return job; + } + + private Job writeDocTopicInference(Configuration conf, Path corpus, Path modelInput, Path output) + throws IOException, ClassNotFoundException, InterruptedException { + String jobName = String.format("Writing final document/topic inference from %s to %s", corpus, output); + log.info("About to run: {}", jobName); + + Job job = prepareJob(corpus, output, SequenceFileInputFormat.class, CVB0DocInferenceMapper.class, + IntWritable.class, VectorWritable.class, SequenceFileOutputFormat.class, jobName); + + FileSystem fs = FileSystem.get(corpus.toUri(), conf); + if (modelInput != null && fs.exists(modelInput)) { + FileStatus[] statuses = fs.listStatus(modelInput, PathFilters.partFilter()); + URI[] modelUris = new URI[statuses.length]; + for (int i = 0; i < statuses.length; i++) { + modelUris[i] = statuses[i].getPath().toUri(); + } + DistributedCache.setCacheFiles(modelUris, conf); + setModelPaths(job, modelInput); + } + job.submit(); + return job; + } + + public static Path modelPath(Path topicModelStateTempPath, int iterationNumber) { + return new Path(topicModelStateTempPath, "model-" + iterationNumber); + } + + public static Path perplexityPath(Path topicModelStateTempPath, int iterationNumber) { + return new Path(topicModelStateTempPath, "perplexity-" + iterationNumber); + } + + private static int getCurrentIterationNumber(Configuration config, Path modelTempDir, int maxIterations) + throws IOException { + FileSystem fs = FileSystem.get(modelTempDir.toUri(), config); + int iterationNumber = 1; + Path iterationPath = modelPath(modelTempDir, iterationNumber); + while (fs.exists(iterationPath) && iterationNumber <= maxIterations) { + log.info("Found previous state: {}", iterationPath); + iterationNumber++; + iterationPath = modelPath(modelTempDir, iterationNumber); + } + return iterationNumber - 1; + } + + public void runIteration(Configuration conf, Path corpusInput, Path modelInput, Path modelOutput, + int iterationNumber, int maxIterations, int numReduceTasks) + throws IOException, ClassNotFoundException, InterruptedException { + String jobName = String.format("Iteration %d of %d, input path: %s", + iterationNumber, maxIterations, modelInput); + log.info("About to run: {}", jobName); + Job job = prepareJob(corpusInput, modelOutput, CachingCVB0Mapper.class, IntWritable.class, VectorWritable.class, + VectorSumReducer.class, IntWritable.class, VectorWritable.class); + job.setCombinerClass(VectorSumReducer.class); + job.setNumReduceTasks(numReduceTasks); + job.setJobName(jobName); + setModelPaths(job, modelInput); + HadoopUtil.delete(conf, modelOutput); + if (!job.waitForCompletion(true)) { + throw new InterruptedException(String.format("Failed to complete iteration %d stage 1", + iterationNumber)); + } + } + + private static void setModelPaths(Job job, Path modelPath) throws IOException { + Configuration conf = job.getConfiguration(); + if (modelPath == null || !FileSystem.get(modelPath.toUri(), conf).exists(modelPath)) { + return; + } + FileStatus[] statuses = FileSystem.get(modelPath.toUri(), conf).listStatus(modelPath, PathFilters.partFilter()); + Preconditions.checkState(statuses.length > 0, "No part files found in model path '%s'", modelPath.toString()); + String[] modelPaths = new String[statuses.length]; + for (int i = 0; i < statuses.length; i++) { + modelPaths[i] = statuses[i].getPath().toUri().toString(); + } + conf.setStrings(MODEL_PATHS, modelPaths); + } + + public static Path[] getModelPaths(Configuration conf) { + String[] modelPathNames = conf.getStrings(MODEL_PATHS); + if (modelPathNames == null || modelPathNames.length == 0) { + return null; + } + Path[] modelPaths = new Path[modelPathNames.length]; + for (int i = 0; i < modelPathNames.length; i++) { + modelPaths[i] = new Path(modelPathNames[i]); + } + return modelPaths; + } + + public static void main(String[] args) throws Exception { + ToolRunner.run(new Configuration(), new CVB0Driver(), args); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/CVB0TopicTermVectorNormalizerMapper.java ---------------------------------------------------------------------- diff --git a/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/CVB0TopicTermVectorNormalizerMapper.java b/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/CVB0TopicTermVectorNormalizerMapper.java new file mode 100644 index 0000000..1253942 --- /dev/null +++ b/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/CVB0TopicTermVectorNormalizerMapper.java @@ -0,0 +1,38 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.mahout.clustering.lda.cvb; + +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.mapreduce.Mapper; +import org.apache.mahout.math.VectorWritable; +import org.apache.mahout.math.function.Functions; + +import java.io.IOException; + +/** + * Performs L1 normalization of input vectors. + */ +public class CVB0TopicTermVectorNormalizerMapper extends + Mapper<IntWritable, VectorWritable, IntWritable, VectorWritable> { + + @Override + protected void map(IntWritable key, VectorWritable value, Context context) throws IOException, + InterruptedException { + value.get().assign(Functions.div(value.get().norm(1.0))); + context.write(key, value); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/CachingCVB0Mapper.java ---------------------------------------------------------------------- diff --git a/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/CachingCVB0Mapper.java b/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/CachingCVB0Mapper.java new file mode 100644 index 0000000..96f36d4 --- /dev/null +++ b/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/CachingCVB0Mapper.java @@ -0,0 +1,133 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.mahout.clustering.lda.cvb; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.mapreduce.Mapper; +import org.apache.mahout.common.RandomUtils; +import org.apache.mahout.math.DenseVector; +import org.apache.mahout.math.MatrixSlice; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.VectorWritable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; + +/** + * Run ensemble learning via loading the {@link ModelTrainer} with two {@link TopicModel} instances: + * one from the previous iteration, the other empty. Inference is done on the first, and the + * learning updates are stored in the second, and only emitted at cleanup(). + * <p/> + * In terms of obvious performance improvements still available, the memory footprint in this + * Mapper could be dropped by half if we accumulated model updates onto the model we're using + * for inference, which might also speed up convergence, as we'd be able to take advantage of + * learning <em>during</em> iteration, not just after each one is done. Most likely we don't + * really need to accumulate double values in the model either, floats would most likely be + * sufficient. Between these two, we could squeeze another factor of 4 in memory efficiency. + * <p/> + * In terms of CPU, we're re-learning the p(topic|doc) distribution on every iteration, starting + * from scratch. This is usually only 10 fixed-point iterations per doc, but that's 10x more than + * only 1. To avoid having to do this, we would need to do a map-side join of the unchanging + * corpus with the continually-improving p(topic|doc) matrix, and then emit multiple outputs + * from the mappers to make sure we can do the reduce model averaging as well. Tricky, but + * possibly worth it. + * <p/> + * {@link ModelTrainer} already takes advantage (in maybe the not-nice way) of multi-core + * availability by doing multithreaded learning, see that class for details. + */ +public class CachingCVB0Mapper + extends Mapper<IntWritable, VectorWritable, IntWritable, VectorWritable> { + + private static final Logger log = LoggerFactory.getLogger(CachingCVB0Mapper.class); + + private ModelTrainer modelTrainer; + private TopicModel readModel; + private TopicModel writeModel; + private int maxIters; + private int numTopics; + + protected ModelTrainer getModelTrainer() { + return modelTrainer; + } + + protected int getMaxIters() { + return maxIters; + } + + protected int getNumTopics() { + return numTopics; + } + + @Override + protected void setup(Context context) throws IOException, InterruptedException { + log.info("Retrieving configuration"); + Configuration conf = context.getConfiguration(); + float eta = conf.getFloat(CVB0Driver.TERM_TOPIC_SMOOTHING, Float.NaN); + float alpha = conf.getFloat(CVB0Driver.DOC_TOPIC_SMOOTHING, Float.NaN); + long seed = conf.getLong(CVB0Driver.RANDOM_SEED, 1234L); + numTopics = conf.getInt(CVB0Driver.NUM_TOPICS, -1); + int numTerms = conf.getInt(CVB0Driver.NUM_TERMS, -1); + int numUpdateThreads = conf.getInt(CVB0Driver.NUM_UPDATE_THREADS, 1); + int numTrainThreads = conf.getInt(CVB0Driver.NUM_TRAIN_THREADS, 4); + maxIters = conf.getInt(CVB0Driver.MAX_ITERATIONS_PER_DOC, 10); + float modelWeight = conf.getFloat(CVB0Driver.MODEL_WEIGHT, 1.0f); + + log.info("Initializing read model"); + Path[] modelPaths = CVB0Driver.getModelPaths(conf); + if (modelPaths != null && modelPaths.length > 0) { + readModel = new TopicModel(conf, eta, alpha, null, numUpdateThreads, modelWeight, modelPaths); + } else { + log.info("No model files found"); + readModel = new TopicModel(numTopics, numTerms, eta, alpha, RandomUtils.getRandom(seed), null, + numTrainThreads, modelWeight); + } + + log.info("Initializing write model"); + writeModel = modelWeight == 1 + ? new TopicModel(numTopics, numTerms, eta, alpha, null, numUpdateThreads) + : readModel; + + log.info("Initializing model trainer"); + modelTrainer = new ModelTrainer(readModel, writeModel, numTrainThreads, numTopics, numTerms); + modelTrainer.start(); + } + + @Override + public void map(IntWritable docId, VectorWritable document, Context context) + throws IOException, InterruptedException { + /* where to get docTopics? */ + Vector topicVector = new DenseVector(numTopics).assign(1.0 / numTopics); + modelTrainer.train(document.get(), topicVector, true, maxIters); + } + + @Override + protected void cleanup(Context context) throws IOException, InterruptedException { + log.info("Stopping model trainer"); + modelTrainer.stop(); + + log.info("Writing model"); + TopicModel readFrom = modelTrainer.getReadModel(); + for (MatrixSlice topic : readFrom) { + context.write(new IntWritable(topic.index()), new VectorWritable(topic.vector())); + } + readModel.stop(); + writeModel.stop(); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/CachingCVB0PerplexityMapper.java ---------------------------------------------------------------------- diff --git a/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/CachingCVB0PerplexityMapper.java b/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/CachingCVB0PerplexityMapper.java new file mode 100644 index 0000000..da77baf --- /dev/null +++ b/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/CachingCVB0PerplexityMapper.java @@ -0,0 +1,108 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.mahout.clustering.lda.cvb; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.DoubleWritable; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.mapreduce.Mapper; +import org.apache.mahout.common.MemoryUtil; +import org.apache.mahout.common.RandomUtils; +import org.apache.mahout.math.DenseVector; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.VectorWritable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.Random; + +public class CachingCVB0PerplexityMapper extends + Mapper<IntWritable, VectorWritable, DoubleWritable, DoubleWritable> { + /** + * Hadoop counters for {@link CachingCVB0PerplexityMapper}, to aid in debugging. + */ + public enum Counters { + SAMPLED_DOCUMENTS + } + + private static final Logger log = LoggerFactory.getLogger(CachingCVB0PerplexityMapper.class); + + private ModelTrainer modelTrainer; + private TopicModel readModel; + private int maxIters; + private int numTopics; + private float testFraction; + private Random random; + private Vector topicVector; + private final DoubleWritable outKey = new DoubleWritable(); + private final DoubleWritable outValue = new DoubleWritable(); + + @Override + protected void setup(Context context) throws IOException, InterruptedException { + MemoryUtil.startMemoryLogger(5000); + + log.info("Retrieving configuration"); + Configuration conf = context.getConfiguration(); + float eta = conf.getFloat(CVB0Driver.TERM_TOPIC_SMOOTHING, Float.NaN); + float alpha = conf.getFloat(CVB0Driver.DOC_TOPIC_SMOOTHING, Float.NaN); + long seed = conf.getLong(CVB0Driver.RANDOM_SEED, 1234L); + random = RandomUtils.getRandom(seed); + numTopics = conf.getInt(CVB0Driver.NUM_TOPICS, -1); + int numTerms = conf.getInt(CVB0Driver.NUM_TERMS, -1); + int numUpdateThreads = conf.getInt(CVB0Driver.NUM_UPDATE_THREADS, 1); + int numTrainThreads = conf.getInt(CVB0Driver.NUM_TRAIN_THREADS, 4); + maxIters = conf.getInt(CVB0Driver.MAX_ITERATIONS_PER_DOC, 10); + float modelWeight = conf.getFloat(CVB0Driver.MODEL_WEIGHT, 1.0f); + testFraction = conf.getFloat(CVB0Driver.TEST_SET_FRACTION, 0.1f); + + log.info("Initializing read model"); + Path[] modelPaths = CVB0Driver.getModelPaths(conf); + if (modelPaths != null && modelPaths.length > 0) { + readModel = new TopicModel(conf, eta, alpha, null, numUpdateThreads, modelWeight, modelPaths); + } else { + log.info("No model files found"); + readModel = new TopicModel(numTopics, numTerms, eta, alpha, RandomUtils.getRandom(seed), null, + numTrainThreads, modelWeight); + } + + log.info("Initializing model trainer"); + modelTrainer = new ModelTrainer(readModel, null, numTrainThreads, numTopics, numTerms); + + log.info("Initializing topic vector"); + topicVector = new DenseVector(new double[numTopics]); + } + + @Override + protected void cleanup(Context context) throws IOException, InterruptedException { + readModel.stop(); + MemoryUtil.stopMemoryLogger(); + } + + @Override + public void map(IntWritable docId, VectorWritable document, Context context) + throws IOException, InterruptedException { + if (testFraction < 1.0f && random.nextFloat() >= testFraction) { + return; + } + context.getCounter(Counters.SAMPLED_DOCUMENTS).increment(1); + outKey.set(document.get().norm(1)); + outValue.set(modelTrainer.calculatePerplexity(document.get(), topicVector.assign(1.0 / numTopics), maxIters)); + context.write(outKey, outValue); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/InMemoryCollapsedVariationalBayes0.java ---------------------------------------------------------------------- diff --git a/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/InMemoryCollapsedVariationalBayes0.java b/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/InMemoryCollapsedVariationalBayes0.java new file mode 100644 index 0000000..07ae100 --- /dev/null +++ b/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/InMemoryCollapsedVariationalBayes0.java @@ -0,0 +1,515 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.mahout.clustering.lda.cvb; + +import com.google.common.collect.Lists; +import com.google.common.collect.Maps; +import org.apache.commons.cli2.CommandLine; +import org.apache.commons.cli2.Group; +import org.apache.commons.cli2.Option; +import org.apache.commons.cli2.OptionException; +import org.apache.commons.cli2.builder.ArgumentBuilder; +import org.apache.commons.cli2.builder.DefaultOptionBuilder; +import org.apache.commons.cli2.builder.GroupBuilder; +import org.apache.commons.cli2.commandline.Parser; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FileStatus; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.Writable; +import org.apache.hadoop.util.ToolRunner; +import org.apache.mahout.common.AbstractJob; +import org.apache.mahout.common.CommandLineUtil; +import org.apache.mahout.common.Pair; +import org.apache.mahout.common.RandomUtils; +import org.apache.mahout.common.commandline.DefaultOptionCreator; +import org.apache.mahout.common.iterator.sequencefile.PathFilters; +import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable; +import org.apache.mahout.math.DenseMatrix; +import org.apache.mahout.math.DenseVector; +import org.apache.mahout.math.DistributedRowMatrixWriter; +import org.apache.mahout.math.Matrix; +import org.apache.mahout.math.SparseRowMatrix; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.VectorWritable; +import org.apache.mahout.math.NamedVector; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +/** + * Runs the same algorithm as {@link CVB0Driver}, but sequentially, in memory. Memory requirements + * are currently: the entire corpus is read into RAM, two copies of the model (each of size + * numTerms * numTopics), and another matrix of size numDocs * numTopics is held in memory + * (to store p(topic|doc) for all docs). + * + * But if all this fits in memory, this can be significantly faster than an iterative MR job. + */ +public class InMemoryCollapsedVariationalBayes0 extends AbstractJob { + + private static final Logger log = LoggerFactory.getLogger(InMemoryCollapsedVariationalBayes0.class); + + private int numTopics; + private int numTerms; + private int numDocuments; + private double alpha; + private double eta; + //private int minDfCt; + //private double maxDfPct; + private boolean verbose = false; + private String[] terms; // of length numTerms; + private Matrix corpusWeights; // length numDocs; + private double totalCorpusWeight; + private double initialModelCorpusFraction; + private Matrix docTopicCounts; + private int numTrainingThreads; + private int numUpdatingThreads; + private ModelTrainer modelTrainer; + + private InMemoryCollapsedVariationalBayes0() { + // only for main usage + } + + public void setVerbose(boolean verbose) { + this.verbose = verbose; + } + + public InMemoryCollapsedVariationalBayes0(Matrix corpus, + String[] terms, + int numTopics, + double alpha, + double eta, + int numTrainingThreads, + int numUpdatingThreads, + double modelCorpusFraction) { + //this.seed = seed; + this.numTopics = numTopics; + this.alpha = alpha; + this.eta = eta; + //this.minDfCt = 0; + //this.maxDfPct = 1.0f; + corpusWeights = corpus; + numDocuments = corpus.numRows(); + this.terms = terms; + this.initialModelCorpusFraction = modelCorpusFraction; + numTerms = terms != null ? terms.length : corpus.numCols(); + Map<String, Integer> termIdMap = Maps.newHashMap(); + if (terms != null) { + for (int t = 0; t < terms.length; t++) { + termIdMap.put(terms[t], t); + } + } + this.numTrainingThreads = numTrainingThreads; + this.numUpdatingThreads = numUpdatingThreads; + postInitCorpus(); + initializeModel(); + } + + private void postInitCorpus() { + totalCorpusWeight = 0; + int numNonZero = 0; + for (int i = 0; i < numDocuments; i++) { + Vector v = corpusWeights.viewRow(i); + double norm; + if (v != null && (norm = v.norm(1)) != 0) { + numNonZero += v.getNumNondefaultElements(); + totalCorpusWeight += norm; + } + } + String s = "Initializing corpus with %d docs, %d terms, %d nonzero entries, total termWeight %f"; + log.info(String.format(s, numDocuments, numTerms, numNonZero, totalCorpusWeight)); + } + + private void initializeModel() { + TopicModel topicModel = new TopicModel(numTopics, numTerms, eta, alpha, RandomUtils.getRandom(), terms, + numUpdatingThreads, initialModelCorpusFraction == 0 ? 1 : initialModelCorpusFraction * totalCorpusWeight); + topicModel.setConf(getConf()); + + TopicModel updatedModel = initialModelCorpusFraction == 0 + ? new TopicModel(numTopics, numTerms, eta, alpha, null, terms, numUpdatingThreads, 1) + : topicModel; + updatedModel.setConf(getConf()); + docTopicCounts = new DenseMatrix(numDocuments, numTopics); + docTopicCounts.assign(1.0 / numTopics); + modelTrainer = new ModelTrainer(topicModel, updatedModel, numTrainingThreads, numTopics, numTerms); + } + + /* + private void inferDocuments(double convergence, int maxIter, boolean recalculate) { + for (int docId = 0; docId < corpusWeights.numRows() ; docId++) { + Vector inferredDocument = topicModel.infer(corpusWeights.viewRow(docId), + docTopicCounts.viewRow(docId)); + // do what now? + } + } + */ + + public void trainDocuments() { + trainDocuments(0); + } + + public void trainDocuments(double testFraction) { + long start = System.nanoTime(); + modelTrainer.start(); + for (int docId = 0; docId < corpusWeights.numRows(); docId++) { + if (testFraction == 0 || docId % (1 / testFraction) != 0) { + Vector docTopics = new DenseVector(numTopics).assign(1.0 / numTopics); // docTopicCounts.getRow(docId) + modelTrainer.trainSync(corpusWeights.viewRow(docId), docTopics , true, 10); + } + } + modelTrainer.stop(); + logTime("train documents", System.nanoTime() - start); + } + + /* + private double error(int docId) { + Vector docTermCounts = corpusWeights.viewRow(docId); + if (docTermCounts == null) { + return 0; + } else { + Vector expectedDocTermCounts = + topicModel.infer(corpusWeights.viewRow(docId), docTopicCounts.viewRow(docId)); + double expectedNorm = expectedDocTermCounts.norm(1); + return expectedDocTermCounts.times(docTermCounts.norm(1)/expectedNorm) + .minus(docTermCounts).norm(1); + } + } + + private double error() { + long time = System.nanoTime(); + double error = 0; + for (int docId = 0; docId < numDocuments; docId++) { + error += error(docId); + } + logTime("error calculation", System.nanoTime() - time); + return error / totalCorpusWeight; + } + */ + + public double iterateUntilConvergence(double minFractionalErrorChange, + int maxIterations, int minIter) { + return iterateUntilConvergence(minFractionalErrorChange, maxIterations, minIter, 0); + } + + public double iterateUntilConvergence(double minFractionalErrorChange, + int maxIterations, int minIter, double testFraction) { + int iter = 0; + double oldPerplexity = 0; + while (iter < minIter) { + trainDocuments(testFraction); + if (verbose) { + log.info("model after: {}: {}", iter, modelTrainer.getReadModel()); + } + log.info("iteration {} complete", iter); + oldPerplexity = modelTrainer.calculatePerplexity(corpusWeights, docTopicCounts, + testFraction); + log.info("{} = perplexity", oldPerplexity); + iter++; + } + double newPerplexity = 0; + double fractionalChange = Double.MAX_VALUE; + while (iter < maxIterations && fractionalChange > minFractionalErrorChange) { + trainDocuments(); + if (verbose) { + log.info("model after: {}: {}", iter, modelTrainer.getReadModel()); + } + newPerplexity = modelTrainer.calculatePerplexity(corpusWeights, docTopicCounts, + testFraction); + log.info("{} = perplexity", newPerplexity); + iter++; + fractionalChange = Math.abs(newPerplexity - oldPerplexity) / oldPerplexity; + log.info("{} = fractionalChange", fractionalChange); + oldPerplexity = newPerplexity; + } + if (iter < maxIterations) { + log.info(String.format("Converged! fractional error change: %f, error %f", + fractionalChange, newPerplexity)); + } else { + log.info(String.format("Reached max iteration count (%d), fractional error change: %f, error: %f", + maxIterations, fractionalChange, newPerplexity)); + } + return newPerplexity; + } + + public void writeModel(Path outputPath) throws IOException { + modelTrainer.persist(outputPath); + } + + private static void logTime(String label, long nanos) { + log.info("{} time: {}ms", label, nanos / 1.0e6); + } + + public static int main2(String[] args, Configuration conf) throws Exception { + DefaultOptionBuilder obuilder = new DefaultOptionBuilder(); + ArgumentBuilder abuilder = new ArgumentBuilder(); + GroupBuilder gbuilder = new GroupBuilder(); + + Option helpOpt = DefaultOptionCreator.helpOption(); + + Option inputDirOpt = obuilder.withLongName("input").withRequired(true).withArgument( + abuilder.withName("input").withMinimum(1).withMaximum(1).create()).withDescription( + "The Directory on HDFS containing the collapsed, properly formatted files having " + + "one doc per line").withShortName("i").create(); + + Option dictOpt = obuilder.withLongName("dictionary").withRequired(false).withArgument( + abuilder.withName("dictionary").withMinimum(1).withMaximum(1).create()).withDescription( + "The path to the term-dictionary format is ... ").withShortName("d").create(); + + Option dfsOpt = obuilder.withLongName("dfs").withRequired(false).withArgument( + abuilder.withName("dfs").withMinimum(1).withMaximum(1).create()).withDescription( + "HDFS namenode URI").withShortName("dfs").create(); + + Option numTopicsOpt = obuilder.withLongName("numTopics").withRequired(true).withArgument(abuilder + .withName("numTopics").withMinimum(1).withMaximum(1) + .create()).withDescription("Number of topics to learn").withShortName("top").create(); + + Option outputTopicFileOpt = obuilder.withLongName("topicOutputFile").withRequired(true).withArgument( + abuilder.withName("topicOutputFile").withMinimum(1).withMaximum(1).create()) + .withDescription("File to write out p(term | topic)").withShortName("to").create(); + + Option outputDocFileOpt = obuilder.withLongName("docOutputFile").withRequired(true).withArgument( + abuilder.withName("docOutputFile").withMinimum(1).withMaximum(1).create()) + .withDescription("File to write out p(topic | docid)").withShortName("do").create(); + + Option alphaOpt = obuilder.withLongName("alpha").withRequired(false).withArgument(abuilder + .withName("alpha").withMinimum(1).withMaximum(1).withDefault("0.1").create()) + .withDescription("Smoothing parameter for p(topic | document) prior").withShortName("a").create(); + + Option etaOpt = obuilder.withLongName("eta").withRequired(false).withArgument(abuilder + .withName("eta").withMinimum(1).withMaximum(1).withDefault("0.1").create()) + .withDescription("Smoothing parameter for p(term | topic)").withShortName("e").create(); + + Option maxIterOpt = obuilder.withLongName("maxIterations").withRequired(false).withArgument(abuilder + .withName("maxIterations").withMinimum(1).withMaximum(1).withDefault("10").create()) + .withDescription("Maximum number of training passes").withShortName("m").create(); + + Option modelCorpusFractionOption = obuilder.withLongName("modelCorpusFraction") + .withRequired(false).withArgument(abuilder.withName("modelCorpusFraction").withMinimum(1) + .withMaximum(1).withDefault("0.0").create()).withShortName("mcf") + .withDescription("For online updates, initial value of |model|/|corpus|").create(); + + Option burnInOpt = obuilder.withLongName("burnInIterations").withRequired(false).withArgument(abuilder + .withName("burnInIterations").withMinimum(1).withMaximum(1).withDefault("5").create()) + .withDescription("Minimum number of iterations").withShortName("b").create(); + + Option convergenceOpt = obuilder.withLongName("convergence").withRequired(false).withArgument(abuilder + .withName("convergence").withMinimum(1).withMaximum(1).withDefault("0.0").create()) + .withDescription("Fractional rate of perplexity to consider convergence").withShortName("c").create(); + + Option reInferDocTopicsOpt = obuilder.withLongName("reInferDocTopics").withRequired(false) + .withArgument(abuilder.withName("reInferDocTopics").withMinimum(1).withMaximum(1) + .withDefault("no").create()) + .withDescription("re-infer p(topic | doc) : [no | randstart | continue]") + .withShortName("rdt").create(); + + Option numTrainThreadsOpt = obuilder.withLongName("numTrainThreads").withRequired(false) + .withArgument(abuilder.withName("numTrainThreads").withMinimum(1).withMaximum(1) + .withDefault("1").create()) + .withDescription("number of threads to train with") + .withShortName("ntt").create(); + + Option numUpdateThreadsOpt = obuilder.withLongName("numUpdateThreads").withRequired(false) + .withArgument(abuilder.withName("numUpdateThreads").withMinimum(1).withMaximum(1) + .withDefault("1").create()) + .withDescription("number of threads to update the model with") + .withShortName("nut").create(); + + Option verboseOpt = obuilder.withLongName("verbose").withRequired(false) + .withArgument(abuilder.withName("verbose").withMinimum(1).withMaximum(1) + .withDefault("false").create()) + .withDescription("print verbose information, like top-terms in each topic, during iteration") + .withShortName("v").create(); + + Group group = gbuilder.withName("Options").withOption(inputDirOpt).withOption(numTopicsOpt) + .withOption(alphaOpt).withOption(etaOpt) + .withOption(maxIterOpt).withOption(burnInOpt).withOption(convergenceOpt) + .withOption(dictOpt).withOption(reInferDocTopicsOpt) + .withOption(outputDocFileOpt).withOption(outputTopicFileOpt).withOption(dfsOpt) + .withOption(numTrainThreadsOpt).withOption(numUpdateThreadsOpt) + .withOption(modelCorpusFractionOption).withOption(verboseOpt).create(); + + try { + Parser parser = new Parser(); + + parser.setGroup(group); + parser.setHelpOption(helpOpt); + CommandLine cmdLine = parser.parse(args); + if (cmdLine.hasOption(helpOpt)) { + CommandLineUtil.printHelp(group); + return -1; + } + + String inputDirString = (String) cmdLine.getValue(inputDirOpt); + String dictDirString = cmdLine.hasOption(dictOpt) ? (String)cmdLine.getValue(dictOpt) : null; + int numTopics = Integer.parseInt((String) cmdLine.getValue(numTopicsOpt)); + double alpha = Double.parseDouble((String)cmdLine.getValue(alphaOpt)); + double eta = Double.parseDouble((String)cmdLine.getValue(etaOpt)); + int maxIterations = Integer.parseInt((String)cmdLine.getValue(maxIterOpt)); + int burnInIterations = Integer.parseInt((String)cmdLine.getValue(burnInOpt)); + double minFractionalErrorChange = Double.parseDouble((String) cmdLine.getValue(convergenceOpt)); + int numTrainThreads = Integer.parseInt((String)cmdLine.getValue(numTrainThreadsOpt)); + int numUpdateThreads = Integer.parseInt((String)cmdLine.getValue(numUpdateThreadsOpt)); + String topicOutFile = (String)cmdLine.getValue(outputTopicFileOpt); + String docOutFile = (String)cmdLine.getValue(outputDocFileOpt); + //String reInferDocTopics = (String)cmdLine.getValue(reInferDocTopicsOpt); + boolean verbose = Boolean.parseBoolean((String) cmdLine.getValue(verboseOpt)); + double modelCorpusFraction = Double.parseDouble((String)cmdLine.getValue(modelCorpusFractionOption)); + + long start = System.nanoTime(); + + if (conf.get("fs.default.name") == null) { + String dfsNameNode = (String)cmdLine.getValue(dfsOpt); + conf.set("fs.default.name", dfsNameNode); + } + String[] terms = loadDictionary(dictDirString, conf); + logTime("dictionary loading", System.nanoTime() - start); + start = System.nanoTime(); + Matrix corpus = loadVectors(inputDirString, conf); + logTime("vector seqfile corpus loading", System.nanoTime() - start); + start = System.nanoTime(); + InMemoryCollapsedVariationalBayes0 cvb0 = + new InMemoryCollapsedVariationalBayes0(corpus, terms, numTopics, alpha, eta, + numTrainThreads, numUpdateThreads, modelCorpusFraction); + logTime("cvb0 init", System.nanoTime() - start); + + start = System.nanoTime(); + cvb0.setVerbose(verbose); + cvb0.iterateUntilConvergence(minFractionalErrorChange, maxIterations, burnInIterations); + logTime("total training time", System.nanoTime() - start); + + /* + if ("randstart".equalsIgnoreCase(reInferDocTopics)) { + cvb0.inferDocuments(0.0, 100, true); + } else if ("continue".equalsIgnoreCase(reInferDocTopics)) { + cvb0.inferDocuments(0.0, 100, false); + } + */ + + start = System.nanoTime(); + cvb0.writeModel(new Path(topicOutFile)); + DistributedRowMatrixWriter.write(new Path(docOutFile), conf, cvb0.docTopicCounts); + logTime("printTopics", System.nanoTime() - start); + } catch (OptionException e) { + log.error("Error while parsing options", e); + CommandLineUtil.printHelp(group); + } + return 0; + } + + /* + private static Map<Integer, Map<String, Integer>> loadCorpus(String path) throws IOException { + List<String> lines = Resources.readLines(Resources.getResource(path), Charsets.UTF_8); + Map<Integer, Map<String, Integer>> corpus = Maps.newHashMap(); + for (int i=0; i<lines.size(); i++) { + String line = lines.get(i); + Map<String, Integer> doc = Maps.newHashMap(); + for (String s : line.split(" ")) { + s = s.replaceAll("\\W", "").toLowerCase().trim(); + if (s.length() == 0) { + continue; + } + if (!doc.containsKey(s)) { + doc.put(s, 0); + } + doc.put(s, doc.get(s) + 1); + } + corpus.put(i, doc); + } + return corpus; + } + */ + + private static String[] loadDictionary(String dictionaryPath, Configuration conf) { + if (dictionaryPath == null) { + return null; + } + Path dictionaryFile = new Path(dictionaryPath); + List<Pair<Integer, String>> termList = Lists.newArrayList(); + int maxTermId = 0; + // key is word value is id + for (Pair<Writable, IntWritable> record + : new SequenceFileIterable<Writable, IntWritable>(dictionaryFile, true, conf)) { + termList.add(new Pair<>(record.getSecond().get(), + record.getFirst().toString())); + maxTermId = Math.max(maxTermId, record.getSecond().get()); + } + String[] terms = new String[maxTermId + 1]; + for (Pair<Integer, String> pair : termList) { + terms[pair.getFirst()] = pair.getSecond(); + } + return terms; + } + + @Override + public Configuration getConf() { + return super.getConf(); + } + + private static Matrix loadVectors(String vectorPathString, Configuration conf) + throws IOException { + Path vectorPath = new Path(vectorPathString); + FileSystem fs = vectorPath.getFileSystem(conf); + List<Path> subPaths = Lists.newArrayList(); + if (fs.isFile(vectorPath)) { + subPaths.add(vectorPath); + } else { + for (FileStatus fileStatus : fs.listStatus(vectorPath, PathFilters.logsCRCFilter())) { + subPaths.add(fileStatus.getPath()); + } + } + List<Pair<Integer, Vector>> rowList = Lists.newArrayList(); + int numRows = Integer.MIN_VALUE; + int numCols = -1; + boolean sequentialAccess = false; + for (Path subPath : subPaths) { + for (Pair<IntWritable, VectorWritable> record + : new SequenceFileIterable<IntWritable, VectorWritable>(subPath, true, conf)) { + int id = record.getFirst().get(); + Vector vector = record.getSecond().get(); + if (vector instanceof NamedVector) { + vector = ((NamedVector)vector).getDelegate(); + } + if (numCols < 0) { + numCols = vector.size(); + sequentialAccess = vector.isSequentialAccess(); + } + rowList.add(Pair.of(id, vector)); + numRows = Math.max(numRows, id); + } + } + numRows++; + Vector[] rowVectors = new Vector[numRows]; + for (Pair<Integer, Vector> pair : rowList) { + rowVectors[pair.getFirst()] = pair.getSecond(); + } + return new SparseRowMatrix(numRows, numCols, rowVectors, true, !sequentialAccess); + + } + + @Override + public int run(String[] strings) throws Exception { + return main2(strings, getConf()); + } + + public static void main(String[] args) throws Exception { + ToolRunner.run(new InMemoryCollapsedVariationalBayes0(), args); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/ModelTrainer.java ---------------------------------------------------------------------- diff --git a/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/ModelTrainer.java b/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/ModelTrainer.java new file mode 100644 index 0000000..912b6d5 --- /dev/null +++ b/mr/src/main/java/org/apache/mahout/clustering/lda/cvb/ModelTrainer.java @@ -0,0 +1,301 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.mahout.clustering.lda.cvb; + +import com.google.common.collect.Lists; +import com.google.common.collect.Maps; +import org.apache.hadoop.fs.Path; +import org.apache.mahout.math.Matrix; +import org.apache.mahout.math.MatrixSlice; +import org.apache.mahout.math.SparseRowMatrix; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.VectorIterable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ArrayBlockingQueue; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.Callable; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; + +/** + * Multithreaded LDA model trainer class, which primarily operates by running a "map/reduce" + * operation, all in memory locally (ie not a hadoop job!) : the "map" operation is to take + * the "read-only" {@link TopicModel} and use it to iteratively learn the p(topic|term, doc) + * distribution for documents (this can be done in parallel across many documents, as the + * "read-only" model is, well, read-only. Then the outputs of this are "reduced" onto the + * "write" model, and these updates are not parallelizable in the same way: individual + * documents can't be added to the same entries in different threads at the same time, but + * updates across many topics to the same term from the same document can be done in parallel, + * so they are. + * + * Because computation is done asynchronously, when iteration is done, it's important to call + * the stop() method, which blocks until work is complete. + * + * Setting the read model and the write model to be the same object may not quite work yet, + * on account of parallelism badness. + */ +public class ModelTrainer { + + private static final Logger log = LoggerFactory.getLogger(ModelTrainer.class); + + private final int numTopics; + private final int numTerms; + private TopicModel readModel; + private TopicModel writeModel; + private ThreadPoolExecutor threadPool; + private BlockingQueue<Runnable> workQueue; + private final int numTrainThreads; + private final boolean isReadWrite; + + public ModelTrainer(TopicModel initialReadModel, TopicModel initialWriteModel, + int numTrainThreads, int numTopics, int numTerms) { + this.readModel = initialReadModel; + this.writeModel = initialWriteModel; + this.numTrainThreads = numTrainThreads; + this.numTopics = numTopics; + this.numTerms = numTerms; + isReadWrite = initialReadModel == initialWriteModel; + } + + /** + * WARNING: this constructor may not lead to good behavior. What should be verified is that + * the model updating process does not conflict with model reading. It might work, but then + * again, it might not! + * @param model to be used for both reading (inference) and accumulating (learning) + * @param numTrainThreads + * @param numTopics + * @param numTerms + */ + public ModelTrainer(TopicModel model, int numTrainThreads, int numTopics, int numTerms) { + this(model, model, numTrainThreads, numTopics, numTerms); + } + + public TopicModel getReadModel() { + return readModel; + } + + public void start() { + log.info("Starting training threadpool with {} threads", numTrainThreads); + workQueue = new ArrayBlockingQueue<>(numTrainThreads * 10); + threadPool = new ThreadPoolExecutor(numTrainThreads, numTrainThreads, 0, TimeUnit.SECONDS, + workQueue); + threadPool.allowCoreThreadTimeOut(false); + threadPool.prestartAllCoreThreads(); + writeModel.reset(); + } + + public void train(VectorIterable matrix, VectorIterable docTopicCounts) { + train(matrix, docTopicCounts, 1); + } + + public double calculatePerplexity(VectorIterable matrix, VectorIterable docTopicCounts) { + return calculatePerplexity(matrix, docTopicCounts, 0); + } + + public double calculatePerplexity(VectorIterable matrix, VectorIterable docTopicCounts, + double testFraction) { + Iterator<MatrixSlice> docIterator = matrix.iterator(); + Iterator<MatrixSlice> docTopicIterator = docTopicCounts.iterator(); + double perplexity = 0; + double matrixNorm = 0; + while (docIterator.hasNext() && docTopicIterator.hasNext()) { + MatrixSlice docSlice = docIterator.next(); + MatrixSlice topicSlice = docTopicIterator.next(); + int docId = docSlice.index(); + Vector document = docSlice.vector(); + Vector topicDist = topicSlice.vector(); + if (testFraction == 0 || docId % (1 / testFraction) == 0) { + trainSync(document, topicDist, false, 10); + perplexity += readModel.perplexity(document, topicDist); + matrixNorm += document.norm(1); + } + } + return perplexity / matrixNorm; + } + + public void train(VectorIterable matrix, VectorIterable docTopicCounts, int numDocTopicIters) { + start(); + Iterator<MatrixSlice> docIterator = matrix.iterator(); + Iterator<MatrixSlice> docTopicIterator = docTopicCounts.iterator(); + long startTime = System.nanoTime(); + int i = 0; + double[] times = new double[100]; + Map<Vector, Vector> batch = Maps.newHashMap(); + int numTokensInBatch = 0; + long batchStart = System.nanoTime(); + while (docIterator.hasNext() && docTopicIterator.hasNext()) { + i++; + Vector document = docIterator.next().vector(); + Vector topicDist = docTopicIterator.next().vector(); + if (isReadWrite) { + if (batch.size() < numTrainThreads) { + batch.put(document, topicDist); + if (log.isDebugEnabled()) { + numTokensInBatch += document.getNumNondefaultElements(); + } + } else { + batchTrain(batch, true, numDocTopicIters); + long time = System.nanoTime(); + log.debug("trained {} docs with {} tokens, start time {}, end time {}", + numTrainThreads, numTokensInBatch, batchStart, time); + batchStart = time; + numTokensInBatch = 0; + } + } else { + long start = System.nanoTime(); + train(document, topicDist, true, numDocTopicIters); + if (log.isDebugEnabled()) { + times[i % times.length] = + (System.nanoTime() - start) / (1.0e6 * document.getNumNondefaultElements()); + if (i % 100 == 0) { + long time = System.nanoTime() - startTime; + log.debug("trained {} documents in {}ms", i, time / 1.0e6); + if (i % 500 == 0) { + Arrays.sort(times); + log.debug("training took median {}ms per token-instance", times[times.length / 2]); + } + } + } + } + } + stop(); + } + + public void batchTrain(Map<Vector, Vector> batch, boolean update, int numDocTopicsIters) { + while (true) { + try { + List<TrainerRunnable> runnables = Lists.newArrayList(); + for (Map.Entry<Vector, Vector> entry : batch.entrySet()) { + runnables.add(new TrainerRunnable(readModel, null, entry.getKey(), + entry.getValue(), new SparseRowMatrix(numTopics, numTerms, true), + numDocTopicsIters)); + } + threadPool.invokeAll(runnables); + if (update) { + for (TrainerRunnable runnable : runnables) { + writeModel.update(runnable.docTopicModel); + } + } + break; + } catch (InterruptedException e) { + log.warn("Interrupted during batch training, retrying!", e); + } + } + } + + public void train(Vector document, Vector docTopicCounts, boolean update, int numDocTopicIters) { + while (true) { + try { + workQueue.put(new TrainerRunnable(readModel, update + ? writeModel + : null, document, docTopicCounts, new SparseRowMatrix(numTopics, numTerms, true), numDocTopicIters)); + return; + } catch (InterruptedException e) { + log.warn("Interrupted waiting to submit document to work queue: {}", document, e); + } + } + } + + public void trainSync(Vector document, Vector docTopicCounts, boolean update, + int numDocTopicIters) { + new TrainerRunnable(readModel, update + ? writeModel + : null, document, docTopicCounts, new SparseRowMatrix(numTopics, numTerms, true), numDocTopicIters).run(); + } + + public double calculatePerplexity(Vector document, Vector docTopicCounts, int numDocTopicIters) { + TrainerRunnable runner = new TrainerRunnable(readModel, null, document, docTopicCounts, + new SparseRowMatrix(numTopics, numTerms, true), numDocTopicIters); + return runner.call(); + } + + public void stop() { + long startTime = System.nanoTime(); + log.info("Initiating stopping of training threadpool"); + try { + threadPool.shutdown(); + if (!threadPool.awaitTermination(60, TimeUnit.SECONDS)) { + log.warn("Threadpool timed out on await termination - jobs still running!"); + } + long newTime = System.nanoTime(); + log.info("threadpool took: {}ms", (newTime - startTime) / 1.0e6); + startTime = newTime; + readModel.stop(); + newTime = System.nanoTime(); + log.info("readModel.stop() took {}ms", (newTime - startTime) / 1.0e6); + startTime = newTime; + writeModel.stop(); + newTime = System.nanoTime(); + log.info("writeModel.stop() took {}ms", (newTime - startTime) / 1.0e6); + TopicModel tmpModel = writeModel; + writeModel = readModel; + readModel = tmpModel; + } catch (InterruptedException e) { + log.error("Interrupted shutting down!", e); + } + } + + public void persist(Path outputPath) throws IOException { + readModel.persist(outputPath, true); + } + + private static final class TrainerRunnable implements Runnable, Callable<Double> { + private final TopicModel readModel; + private final TopicModel writeModel; + private final Vector document; + private final Vector docTopics; + private final Matrix docTopicModel; + private final int numDocTopicIters; + + private TrainerRunnable(TopicModel readModel, TopicModel writeModel, Vector document, + Vector docTopics, Matrix docTopicModel, int numDocTopicIters) { + this.readModel = readModel; + this.writeModel = writeModel; + this.document = document; + this.docTopics = docTopics; + this.docTopicModel = docTopicModel; + this.numDocTopicIters = numDocTopicIters; + } + + @Override + public void run() { + for (int i = 0; i < numDocTopicIters; i++) { + // synchronous read-only call: + readModel.trainDocTopicModel(document, docTopics, docTopicModel); + } + if (writeModel != null) { + // parallel call which is read-only on the docTopicModel, and write-only on the writeModel + // this method does not return until all rows of the docTopicModel have been submitted + // to write work queues + writeModel.update(docTopicModel); + } + } + + @Override + public Double call() { + run(); + return readModel.perplexity(document, docTopics); + } + } +}
