http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/SparseVectorsFromSequenceFiles.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/SparseVectorsFromSequenceFiles.java b/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/SparseVectorsFromSequenceFiles.java new file mode 100644 index 0000000..ee56124 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/SparseVectorsFromSequenceFiles.java @@ -0,0 +1,369 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.vectorizer; + +import org.apache.commons.cli2.CommandLine; +import org.apache.commons.cli2.Group; +import org.apache.commons.cli2.Option; +import org.apache.commons.cli2.OptionException; +import org.apache.commons.cli2.builder.ArgumentBuilder; +import org.apache.commons.cli2.builder.DefaultOptionBuilder; +import org.apache.commons.cli2.builder.GroupBuilder; +import org.apache.commons.cli2.commandline.Parser; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.util.ToolRunner; +import org.apache.lucene.analysis.Analyzer; +import org.apache.lucene.analysis.standard.StandardAnalyzer; +import org.apache.mahout.common.AbstractJob; +import org.apache.mahout.common.CommandLineUtil; +import org.apache.mahout.common.HadoopUtil; +import org.apache.mahout.common.Pair; +import org.apache.mahout.common.commandline.DefaultOptionCreator; +import org.apache.mahout.common.lucene.AnalyzerUtils; +import org.apache.mahout.math.hadoop.stats.BasicStats; +import org.apache.mahout.vectorizer.collocations.llr.LLRReducer; +import org.apache.mahout.vectorizer.common.PartialVectorMerger; +import org.apache.mahout.vectorizer.tfidf.TFIDFConverter; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.util.List; + +/** + * Converts a given set of sequence files into SparseVectors + */ +public final class SparseVectorsFromSequenceFiles extends AbstractJob { + + private static final Logger log = LoggerFactory.getLogger(SparseVectorsFromSequenceFiles.class); + + public static void main(String[] args) throws Exception { + ToolRunner.run(new SparseVectorsFromSequenceFiles(), args); + } + + @Override + public int run(String[] args) throws Exception { + DefaultOptionBuilder obuilder = new DefaultOptionBuilder(); + ArgumentBuilder abuilder = new ArgumentBuilder(); + GroupBuilder gbuilder = new GroupBuilder(); + + Option inputDirOpt = DefaultOptionCreator.inputOption().create(); + + Option outputDirOpt = DefaultOptionCreator.outputOption().create(); + + Option minSupportOpt = 
obuilder.withLongName("minSupport").withArgument( + abuilder.withName("minSupport").withMinimum(1).withMaximum(1).create()).withDescription( + "(Optional) Minimum Support. Default Value: 2").withShortName("s").create(); + + Option analyzerNameOpt = obuilder.withLongName("analyzerName").withArgument( + abuilder.withName("analyzerName").withMinimum(1).withMaximum(1).create()).withDescription( + "The class name of the analyzer").withShortName("a").create(); + + Option chunkSizeOpt = obuilder.withLongName("chunkSize").withArgument( + abuilder.withName("chunkSize").withMinimum(1).withMaximum(1).create()).withDescription( + "The chunkSize in MegaBytes. Default Value: 100MB").withShortName("chunk").create(); + + Option weightOpt = obuilder.withLongName("weight").withRequired(false).withArgument( + abuilder.withName("weight").withMinimum(1).withMaximum(1).create()).withDescription( + "The kind of weight to use. Currently TF or TFIDF. Default: TFIDF").withShortName("wt").create(); + + Option minDFOpt = obuilder.withLongName("minDF").withRequired(false).withArgument( + abuilder.withName("minDF").withMinimum(1).withMaximum(1).create()).withDescription( + "The minimum document frequency. Default is 1").withShortName("md").create(); + + Option maxDFPercentOpt = obuilder.withLongName("maxDFPercent").withRequired(false).withArgument( + abuilder.withName("maxDFPercent").withMinimum(1).withMaximum(1).create()).withDescription( + "The max percentage of docs for the DF. Can be used to remove really high frequency terms." + + " Expressed as an integer between 0 and 100. Default is 99. If maxDFSigma is also set, " + + "it will override this value.").withShortName("x").create(); + + Option maxDFSigmaOpt = obuilder.withLongName("maxDFSigma").withRequired(false).withArgument( + abuilder.withName("maxDFSigma").withMinimum(1).withMaximum(1).create()).withDescription( + "What portion of the tf (tf-idf) vectors to be used, expressed in times the standard deviation (sigma) " + + "of the document frequencies of these vectors. Can be used to remove really high frequency terms." + + " Expressed as a double value. Good value to be specified is 3.0. In case the value is less " + + "than 0 no vectors will be filtered out. Default is -1.0. Overrides maxDFPercent") + .withShortName("xs").create(); + + Option minLLROpt = obuilder.withLongName("minLLR").withRequired(false).withArgument( + abuilder.withName("minLLR").withMinimum(1).withMaximum(1).create()).withDescription( + "(Optional)The minimum Log Likelihood Ratio(Float) Default is " + LLRReducer.DEFAULT_MIN_LLR) + .withShortName("ml").create(); + + Option numReduceTasksOpt = obuilder.withLongName("numReducers").withArgument( + abuilder.withName("numReducers").withMinimum(1).withMaximum(1).create()).withDescription( + "(Optional) Number of reduce tasks. Default Value: 1").withShortName("nr").create(); + + Option powerOpt = obuilder.withLongName("norm").withRequired(false).withArgument( + abuilder.withName("norm").withMinimum(1).withMaximum(1).create()).withDescription( + "The norm to use, expressed as either a float or \"INF\" if you want to use the Infinite norm. " + + "Must be greater or equal to 0. The default is not to normalize").withShortName("n").create(); + + Option logNormalizeOpt = obuilder.withLongName("logNormalize").withRequired(false) + .withDescription( + "(Optional) Whether output vectors should be logNormalize. 
If set true else false") + .withShortName("lnorm").create(); + + Option maxNGramSizeOpt = obuilder.withLongName("maxNGramSize").withRequired(false).withArgument( + abuilder.withName("ngramSize").withMinimum(1).withMaximum(1).create()) + .withDescription( + "(Optional) The maximum size of ngrams to create" + + " (2 = bigrams, 3 = trigrams, etc) Default Value:1").withShortName("ng").create(); + + Option sequentialAccessVectorOpt = obuilder.withLongName("sequentialAccessVector").withRequired(false) + .withDescription( + "(Optional) Whether output vectors should be SequentialAccessVectors. If set true else false") + .withShortName("seq").create(); + + Option namedVectorOpt = obuilder.withLongName("namedVector").withRequired(false) + .withDescription( + "(Optional) Whether output vectors should be NamedVectors. If set true else false") + .withShortName("nv").create(); + + Option overwriteOutput = obuilder.withLongName("overwrite").withRequired(false).withDescription( + "If set, overwrite the output directory").withShortName("ow").create(); + Option helpOpt = obuilder.withLongName("help").withDescription("Print out help").withShortName("h") + .create(); + + Group group = gbuilder.withName("Options").withOption(minSupportOpt).withOption(analyzerNameOpt) + .withOption(chunkSizeOpt).withOption(outputDirOpt).withOption(inputDirOpt).withOption(minDFOpt) + .withOption(maxDFSigmaOpt).withOption(maxDFPercentOpt).withOption(weightOpt).withOption(powerOpt) + .withOption(minLLROpt).withOption(numReduceTasksOpt).withOption(maxNGramSizeOpt).withOption(overwriteOutput) + .withOption(helpOpt).withOption(sequentialAccessVectorOpt).withOption(namedVectorOpt) + .withOption(logNormalizeOpt) + .create(); + try { + Parser parser = new Parser(); + parser.setGroup(group); + parser.setHelpOption(helpOpt); + CommandLine cmdLine = parser.parse(args); + + if (cmdLine.hasOption(helpOpt)) { + CommandLineUtil.printHelp(group); + return -1; + } + + Path inputDir = new Path((String) cmdLine.getValue(inputDirOpt)); + Path outputDir = new Path((String) cmdLine.getValue(outputDirOpt)); + + int chunkSize = 100; + if (cmdLine.hasOption(chunkSizeOpt)) { + chunkSize = Integer.parseInt((String) cmdLine.getValue(chunkSizeOpt)); + } + int minSupport = 2; + if (cmdLine.hasOption(minSupportOpt)) { + String minSupportString = (String) cmdLine.getValue(minSupportOpt); + minSupport = Integer.parseInt(minSupportString); + } + + int maxNGramSize = 1; + + if (cmdLine.hasOption(maxNGramSizeOpt)) { + try { + maxNGramSize = Integer.parseInt(cmdLine.getValue(maxNGramSizeOpt).toString()); + } catch (NumberFormatException ex) { + log.warn("Could not parse ngram size option"); + } + } + log.info("Maximum n-gram size is: {}", maxNGramSize); + + if (cmdLine.hasOption(overwriteOutput)) { + HadoopUtil.delete(getConf(), outputDir); + } + + float minLLRValue = LLRReducer.DEFAULT_MIN_LLR; + if (cmdLine.hasOption(minLLROpt)) { + minLLRValue = Float.parseFloat(cmdLine.getValue(minLLROpt).toString()); + } + log.info("Minimum LLR value: {}", minLLRValue); + + int reduceTasks = 1; + if (cmdLine.hasOption(numReduceTasksOpt)) { + reduceTasks = Integer.parseInt(cmdLine.getValue(numReduceTasksOpt).toString()); + } + log.info("Number of reduce tasks: {}", reduceTasks); + + Class<? 
extends Analyzer> analyzerClass = StandardAnalyzer.class; + if (cmdLine.hasOption(analyzerNameOpt)) { + String className = cmdLine.getValue(analyzerNameOpt).toString(); + analyzerClass = Class.forName(className).asSubclass(Analyzer.class); + // try instantiating it, b/c there isn't any point in setting it if + // you can't instantiate it + AnalyzerUtils.createAnalyzer(analyzerClass); + } + + boolean processIdf; + + if (cmdLine.hasOption(weightOpt)) { + String wString = cmdLine.getValue(weightOpt).toString(); + if ("tf".equalsIgnoreCase(wString)) { + processIdf = false; + } else if ("tfidf".equalsIgnoreCase(wString)) { + processIdf = true; + } else { + throw new OptionException(weightOpt); + } + } else { + processIdf = true; + } + + int minDf = 1; + if (cmdLine.hasOption(minDFOpt)) { + minDf = Integer.parseInt(cmdLine.getValue(minDFOpt).toString()); + } + int maxDFPercent = 99; + if (cmdLine.hasOption(maxDFPercentOpt)) { + maxDFPercent = Integer.parseInt(cmdLine.getValue(maxDFPercentOpt).toString()); + } + double maxDFSigma = -1.0; + if (cmdLine.hasOption(maxDFSigmaOpt)) { + maxDFSigma = Double.parseDouble(cmdLine.getValue(maxDFSigmaOpt).toString()); + } + + float norm = PartialVectorMerger.NO_NORMALIZING; + if (cmdLine.hasOption(powerOpt)) { + String power = cmdLine.getValue(powerOpt).toString(); + if ("INF".equals(power)) { + norm = Float.POSITIVE_INFINITY; + } else { + norm = Float.parseFloat(power); + } + } + + boolean logNormalize = false; + if (cmdLine.hasOption(logNormalizeOpt)) { + logNormalize = true; + } + log.info("Tokenizing documents in {}", inputDir); + Configuration conf = getConf(); + Path tokenizedPath = new Path(outputDir, DocumentProcessor.TOKENIZED_DOCUMENT_OUTPUT_FOLDER); + //TODO: move this into DictionaryVectorizer , and then fold SparseVectorsFrom with EncodedVectorsFrom + // to have one framework for all of this. + DocumentProcessor.tokenizeDocuments(inputDir, analyzerClass, tokenizedPath, conf); + + boolean sequentialAccessOutput = false; + if (cmdLine.hasOption(sequentialAccessVectorOpt)) { + sequentialAccessOutput = true; + } + + boolean namedVectors = false; + if (cmdLine.hasOption(namedVectorOpt)) { + namedVectors = true; + } + boolean shouldPrune = maxDFSigma >= 0.0 || maxDFPercent > 0.00; + String tfDirName = shouldPrune + ? 
DictionaryVectorizer.DOCUMENT_VECTOR_OUTPUT_FOLDER + "-toprune" + : DictionaryVectorizer.DOCUMENT_VECTOR_OUTPUT_FOLDER; + log.info("Creating Term Frequency Vectors"); + if (processIdf) { + DictionaryVectorizer.createTermFrequencyVectors(tokenizedPath, + outputDir, + tfDirName, + conf, + minSupport, + maxNGramSize, + minLLRValue, + -1.0f, + false, + reduceTasks, + chunkSize, + sequentialAccessOutput, + namedVectors); + } else { + DictionaryVectorizer.createTermFrequencyVectors(tokenizedPath, + outputDir, + tfDirName, + conf, + minSupport, + maxNGramSize, + minLLRValue, + norm, + logNormalize, + reduceTasks, + chunkSize, + sequentialAccessOutput, + namedVectors); + } + + Pair<Long[], List<Path>> docFrequenciesFeatures = null; + // Should document frequency features be processed + if (shouldPrune || processIdf) { + log.info("Calculating IDF"); + docFrequenciesFeatures = + TFIDFConverter.calculateDF(new Path(outputDir, tfDirName), outputDir, conf, chunkSize); + } + + long maxDF = maxDFPercent; //if we are pruning by std dev, then this will get changed + if (shouldPrune) { + long vectorCount = docFrequenciesFeatures.getFirst()[1]; + if (maxDFSigma >= 0.0) { + Path dfDir = new Path(outputDir, TFIDFConverter.WORDCOUNT_OUTPUT_FOLDER); + Path stdCalcDir = new Path(outputDir, HighDFWordsPruner.STD_CALC_DIR); + + // Calculate the standard deviation + double stdDev = BasicStats.stdDevForGivenMean(dfDir, stdCalcDir, 0.0, conf); + maxDF = (int) (100.0 * maxDFSigma * stdDev / vectorCount); + } + + long maxDFThreshold = (long) (vectorCount * (maxDF / 100.0f)); + + // Prune the term frequency vectors + Path tfDir = new Path(outputDir, tfDirName); + Path prunedTFDir = new Path(outputDir, DictionaryVectorizer.DOCUMENT_VECTOR_OUTPUT_FOLDER); + Path prunedPartialTFDir = + new Path(outputDir, DictionaryVectorizer.DOCUMENT_VECTOR_OUTPUT_FOLDER + "-partial"); + log.info("Pruning"); + if (processIdf) { + HighDFWordsPruner.pruneVectors(tfDir, + prunedTFDir, + prunedPartialTFDir, + maxDFThreshold, + minDf, + conf, + docFrequenciesFeatures, + -1.0f, + false, + reduceTasks); + } else { + HighDFWordsPruner.pruneVectors(tfDir, + prunedTFDir, + prunedPartialTFDir, + maxDFThreshold, + minDf, + conf, + docFrequenciesFeatures, + norm, + logNormalize, + reduceTasks); + } + HadoopUtil.delete(new Configuration(conf), tfDir); + } + if (processIdf) { + TFIDFConverter.processTfIdf( + new Path(outputDir, DictionaryVectorizer.DOCUMENT_VECTOR_OUTPUT_FOLDER), + outputDir, conf, docFrequenciesFeatures, minDf, maxDF, norm, logNormalize, + sequentialAccessOutput, namedVectors, reduceTasks); + } + } catch (OptionException e) { + log.error("Exception", e); + CommandLineUtil.printHelp(group); + } + return 0; + } + +}
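
SparseVectorsFromSequenceFiles is the driver class behind the seq2sparse step: it tokenizes the input documents, builds term-frequency vectors (optionally with n-gram collocations pruned by log-likelihood ratio), prunes high-document-frequency terms, and optionally converts the result to TF-IDF. As a rough usage sketch — the paths and option values below are illustrative assumptions, not part of this commit — the tool can be invoked programmatically through ToolRunner using the long-form options defined above:

    // Rough programmatic invocation of the tool above; this class is what the
    // "seq2sparse" command dispatches to. Paths and option values are assumed
    // for illustration only.
    import org.apache.hadoop.util.ToolRunner;
    import org.apache.mahout.vectorizer.SparseVectorsFromSequenceFiles;

    public final class Seq2SparseExample {
      public static void main(String[] args) throws Exception {
        String[] vectorizerArgs = {
            "--input",  "/tmp/reuters-seqfiles",   // SequenceFile<Text,Text> of documents (assumed path)
            "--output", "/tmp/reuters-vectors",    // will hold the dictionary, tf and tf-idf vectors
            "--weight", "tfidf",                   // TF or TFIDF; TFIDF is the default
            "--maxNGramSize", "2",                 // also emit bigram collocations
            "--minLLR", "50",                      // LLR threshold used when pruning collocations
            "--maxDFPercent", "85",                // drop terms occurring in more than 85% of documents
            "--namedVector",                       // keep document ids on the output vectors
            "--overwrite"                          // delete any previous output first
        };
        ToolRunner.run(new SparseVectorsFromSequenceFiles(), vectorizerArgs);
      }
    }
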
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/TF.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/TF.java b/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/TF.java new file mode 100644 index 0000000..1818580 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/TF.java @@ -0,0 +1,30 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.vectorizer; + +/** + * {@link Weight} based on term frequency only + */ +public class TF implements Weight { + + @Override + public double calculate(int tf, int df, int length, int numDocs) { + //ignore length + return tf; + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/TFIDF.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/TFIDF.java b/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/TFIDF.java new file mode 100644 index 0000000..238fa03 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/TFIDF.java @@ -0,0 +1,31 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.vectorizer; + +import org.apache.lucene.search.similarities.ClassicSimilarity; +//TODO: add a new class that supports arbitrary Lucene similarity implementations +public class TFIDF implements Weight { + + private final ClassicSimilarity sim = new ClassicSimilarity(); + + @Override + public double calculate(int tf, int df, int length, int numDocs) { + // ignore length + return sim.tf(tf) * sim.idf(df, numDocs); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/Vectorizer.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/Vectorizer.java b/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/Vectorizer.java new file mode 100644 index 0000000..45f0043 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/Vectorizer.java @@ -0,0 +1,29 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.vectorizer; + +import org.apache.hadoop.fs.Path; + +import java.io.IOException; + +public interface Vectorizer { + + void createVectors(Path input, Path output, VectorizerConfig config) + throws IOException, ClassNotFoundException, InterruptedException; + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/VectorizerConfig.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/VectorizerConfig.java b/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/VectorizerConfig.java new file mode 100644 index 0000000..edaf2f3 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/VectorizerConfig.java @@ -0,0 +1,179 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.vectorizer; + +import org.apache.hadoop.conf.Configuration; + +/** + * The config for a Vectorizer. Not all implementations need use all variables. + */ +public final class VectorizerConfig { + + private Configuration conf; + private String analyzerClassName; + private String encoderName; + private boolean sequentialAccess; + private boolean namedVectors; + private int cardinality; + private String encoderClass; + private String tfDirName; + private int minSupport; + private int maxNGramSize; + private float minLLRValue; + private float normPower; + private boolean logNormalize; + private int numReducers; + private int chunkSizeInMegabytes; + + public VectorizerConfig(Configuration conf, + String analyzerClassName, + String encoderClass, + String encoderName, + boolean sequentialAccess, + boolean namedVectors, + int cardinality) { + this.conf = conf; + this.analyzerClassName = analyzerClassName; + this.encoderClass = encoderClass; + this.encoderName = encoderName; + this.sequentialAccess = sequentialAccess; + this.namedVectors = namedVectors; + this.cardinality = cardinality; + } + + public Configuration getConf() { + return conf; + } + + public void setConf(Configuration conf) { + this.conf = conf; + } + + public String getAnalyzerClassName() { + return analyzerClassName; + } + + public void setAnalyzerClassName(String analyzerClassName) { + this.analyzerClassName = analyzerClassName; + } + + public String getEncoderName() { + return encoderName; + } + + public void setEncoderName(String encoderName) { + this.encoderName = encoderName; + } + + public boolean isSequentialAccess() { + return sequentialAccess; + } + + public void setSequentialAccess(boolean sequentialAccess) { + this.sequentialAccess = sequentialAccess; + } + + + public String getTfDirName() { + return tfDirName; + } + + public void setTfDirName(String tfDirName) { + this.tfDirName = tfDirName; + } + + public boolean isNamedVectors() { + return namedVectors; + } + + public void setNamedVectors(boolean namedVectors) { + this.namedVectors = namedVectors; + } + + public int getCardinality() { + return cardinality; + } + + public void setCardinality(int cardinality) { + this.cardinality = cardinality; + } + + public String getEncoderClass() { + return encoderClass; + } + + public void setEncoderClass(String encoderClass) { + this.encoderClass = encoderClass; + } + + public int getMinSupport() { + return minSupport; + } + + public void setMinSupport(int minSupport) { + this.minSupport = minSupport; + } + + public int getMaxNGramSize() { + return maxNGramSize; + } + + public void setMaxNGramSize(int maxNGramSize) { + this.maxNGramSize = maxNGramSize; + } + + public float getMinLLRValue() { + return minLLRValue; + } + + public void setMinLLRValue(float minLLRValue) { + this.minLLRValue = minLLRValue; + } + + public float getNormPower() { + return normPower; + } + + public void setNormPower(float normPower) { + this.normPower = normPower; + } + + public boolean isLogNormalize() { + return logNormalize; + } + + public void setLogNormalize(boolean logNormalize) { + this.logNormalize = logNormalize; + } + + public int getNumReducers() { + return numReducers; + } + + public void setNumReducers(int numReducers) { + this.numReducers = numReducers; + } + + public int getChunkSizeInMegabytes() { + return chunkSizeInMegabytes; + } + + public void setChunkSizeInMegabytes(int chunkSizeInMegabytes) { + this.chunkSizeInMegabytes = chunkSizeInMegabytes; + } +} 
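
VectorizerConfig is a plain holder for the knobs shared by Vectorizer implementations; only the constructor arguments are mandatory, everything else is populated through setters. The sketch below shows how a caller might assemble one and hand it to a Vectorizer via the interface that follows — the encoder class/name, the numeric values, and the concrete Vectorizer supplied by the caller are all assumptions for illustration, not taken from this commit:

    // Illustration only: filling in a VectorizerConfig and passing it through the
    // Vectorizer interface. The encoder class/name and numeric values are assumed;
    // a concrete Vectorizer implementation is supplied by the caller.
    import org.apache.hadoop.conf.Configuration;
    import org.apache.hadoop.fs.Path;
    import org.apache.lucene.analysis.standard.StandardAnalyzer;
    import org.apache.mahout.vectorizer.Vectorizer;
    import org.apache.mahout.vectorizer.VectorizerConfig;

    public final class VectorizerConfigExample {

      public static void vectorize(Vectorizer vectorizer, Path input, Path output) throws Exception {
        VectorizerConfig config = new VectorizerConfig(
            new Configuration(),
            StandardAnalyzer.class.getName(),   // analyzerClassName used for tokenizing
            "org.apache.mahout.vectorizer.encoders.LuceneTextValueEncoder", // encoderClass (assumed)
            "text",                             // encoderName (assumed)
            true,                               // sequentialAccess output vectors
            true,                               // namedVectors keyed by document id
            1 << 18);                           // cardinality of the feature space (assumed)

        // Optional knobs; the fields default to 0/false/null when not set.
        config.setMinSupport(2);
        config.setMaxNGramSize(2);
        config.setMinLLRValue(50.0f);
        config.setNormPower(2.0f);              // L2 normalization
        config.setLogNormalize(false);
        config.setNumReducers(1);
        config.setChunkSizeInMegabytes(100);

        vectorizer.createVectors(input, output, config);
      }
    }
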
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/Weight.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/Weight.java b/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/Weight.java new file mode 100644 index 0000000..e36159d --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/Weight.java @@ -0,0 +1,32 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.vectorizer; + +public interface Weight { + + /** + * Experimental + * + * @param tf term freq + * @param df doc freq + * @param length Length of the document + * @param numDocs the total number of docs + * @return The weight + */ + double calculate(int tf, int df, int length, int numDocs); +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/collocations/llr/CollocCombiner.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/collocations/llr/CollocCombiner.java b/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/collocations/llr/CollocCombiner.java new file mode 100644 index 0000000..54cadbd --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/collocations/llr/CollocCombiner.java @@ -0,0 +1,46 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.vectorizer.collocations.llr; + +import java.io.IOException; + +import org.apache.hadoop.mapreduce.Reducer; + +/** Combiner for pass1 of the CollocationDriver. 
Combines frequencies for values for the same key */ +public class CollocCombiner extends Reducer<GramKey, Gram, GramKey, Gram> { + + @Override + protected void reduce(GramKey key, Iterable<Gram> values, Context context) throws IOException, InterruptedException { + + int freq = 0; + Gram value = null; + + // accumulate frequencies from values, preserve the last value + // to write to the context. + for (Gram value1 : values) { + value = value1; + freq += value.getFrequency(); + } + + if (value != null) { + value.setFrequency(freq); + context.write(key, value); + } + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/collocations/llr/CollocDriver.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/collocations/llr/CollocDriver.java b/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/collocations/llr/CollocDriver.java new file mode 100644 index 0000000..a07ddbd --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/collocations/llr/CollocDriver.java @@ -0,0 +1,284 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.vectorizer.collocations.llr; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.DoubleWritable; +import org.apache.hadoop.io.Text; +import org.apache.hadoop.mapreduce.Job; +import org.apache.hadoop.mapreduce.Mapper; +import org.apache.hadoop.mapreduce.lib.input.FileInputFormat; +import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat; +import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat; +import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat; +import org.apache.hadoop.util.ToolRunner; +import org.apache.lucene.analysis.Analyzer; +import org.apache.lucene.analysis.standard.StandardAnalyzer; +import org.apache.mahout.common.AbstractJob; +import org.apache.mahout.common.HadoopUtil; +import org.apache.mahout.common.commandline.DefaultOptionCreator; +import org.apache.mahout.common.lucene.AnalyzerUtils; +import org.apache.mahout.vectorizer.DocumentProcessor; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** Driver for LLR Collocation discovery mapreduce job */ +public final class CollocDriver extends AbstractJob { + //public static final String DEFAULT_OUTPUT_DIRECTORY = "output"; + + public static final String SUBGRAM_OUTPUT_DIRECTORY = "subgrams"; + + public static final String NGRAM_OUTPUT_DIRECTORY = "ngrams"; + + public static final String EMIT_UNIGRAMS = "emit-unigrams"; + + public static final boolean DEFAULT_EMIT_UNIGRAMS = false; + + private static final int DEFAULT_MAX_NGRAM_SIZE = 2; + + private static final int DEFAULT_PASS1_NUM_REDUCE_TASKS = 1; + + private static final Logger log = LoggerFactory.getLogger(CollocDriver.class); + + public static void main(String[] args) throws Exception { + ToolRunner.run(new CollocDriver(), args); + } + + @Override + public int run(String[] args) throws Exception { + addInputOption(); + addOutputOption(); + addOption(DefaultOptionCreator.numReducersOption().create()); + + addOption("maxNGramSize", + "ng", + "(Optional) The max size of ngrams to create (2 = bigrams, 3 = trigrams, etc) default: 2", + String.valueOf(DEFAULT_MAX_NGRAM_SIZE)); + addOption("minSupport", "s", "(Optional) Minimum Support. 
Default Value: " + + CollocReducer.DEFAULT_MIN_SUPPORT, String.valueOf(CollocReducer.DEFAULT_MIN_SUPPORT)); + addOption("minLLR", "ml", "(Optional)The minimum Log Likelihood Ratio(Float) Default is " + + LLRReducer.DEFAULT_MIN_LLR, String.valueOf(LLRReducer.DEFAULT_MIN_LLR)); + addOption(DefaultOptionCreator.overwriteOption().create()); + addOption("analyzerName", "a", "The class name of the analyzer to use for preprocessing", null); + + addFlag("preprocess", "p", "If set, input is SequenceFile<Text,Text> where the value is the document, " + + " which will be tokenized using the specified analyzer."); + addFlag("unigram", "u", "If set, unigrams will be emitted in the final output alongside collocations"); + + Map<String, List<String>> argMap = parseArguments(args); + + if (argMap == null) { + return -1; + } + + Path input = getInputPath(); + Path output = getOutputPath(); + + int maxNGramSize = DEFAULT_MAX_NGRAM_SIZE; + if (hasOption("maxNGramSize")) { + try { + maxNGramSize = Integer.parseInt(getOption("maxNGramSize")); + } catch (NumberFormatException ex) { + log.warn("Could not parse ngram size option"); + } + } + log.info("Maximum n-gram size is: {}", maxNGramSize); + + if (hasOption(DefaultOptionCreator.OVERWRITE_OPTION)) { + HadoopUtil.delete(getConf(), output); + } + + int minSupport = CollocReducer.DEFAULT_MIN_SUPPORT; + if (getOption("minSupport") != null) { + minSupport = Integer.parseInt(getOption("minSupport")); + } + log.info("Minimum Support value: {}", minSupport); + + float minLLRValue = LLRReducer.DEFAULT_MIN_LLR; + if (getOption("minLLR") != null) { + minLLRValue = Float.parseFloat(getOption("minLLR")); + } + log.info("Minimum LLR value: {}", minLLRValue); + + int reduceTasks = DEFAULT_PASS1_NUM_REDUCE_TASKS; + if (getOption("maxRed") != null) { + reduceTasks = Integer.parseInt(getOption("maxRed")); + } + log.info("Number of pass1 reduce tasks: {}", reduceTasks); + + boolean emitUnigrams = argMap.containsKey("emitUnigrams"); + + if (argMap.containsKey("preprocess")) { + log.info("Input will be preprocessed"); + Class<? extends Analyzer> analyzerClass = StandardAnalyzer.class; + if (getOption("analyzerName") != null) { + String className = getOption("analyzerName"); + analyzerClass = Class.forName(className).asSubclass(Analyzer.class); + // try instantiating it, b/c there isn't any point in setting it if + // you can't instantiate it + AnalyzerUtils.createAnalyzer(analyzerClass); + } + + Path tokenizedPath = new Path(output, DocumentProcessor.TOKENIZED_DOCUMENT_OUTPUT_FOLDER); + + DocumentProcessor.tokenizeDocuments(input, analyzerClass, tokenizedPath, getConf()); + input = tokenizedPath; + } else { + log.info("Input will NOT be preprocessed"); + } + + // parse input and extract collocations + long ngramCount = + generateCollocations(input, output, getConf(), emitUnigrams, maxNGramSize, reduceTasks, minSupport); + + // tally collocations and perform LLR calculation + computeNGramsPruneByLLR(output, getConf(), ngramCount, emitUnigrams, minLLRValue, reduceTasks); + + return 0; + } + + /** + * Generate all ngrams for the {@link org.apache.mahout.vectorizer.DictionaryVectorizer} job + * + * @param input + * input path containing tokenized documents + * @param output + * output path where ngrams are generated including unigrams + * @param baseConf + * job configuration + * @param maxNGramSize + * minValue = 2. 
+ * @param minSupport + * minimum support to prune ngrams including unigrams + * @param minLLRValue + * minimum threshold to prune ngrams + * @param reduceTasks + * number of reducers used + */ + public static void generateAllGrams(Path input, + Path output, + Configuration baseConf, + int maxNGramSize, + int minSupport, + float minLLRValue, + int reduceTasks) + throws IOException, InterruptedException, ClassNotFoundException { + // parse input and extract collocations + long ngramCount = generateCollocations(input, output, baseConf, true, maxNGramSize, reduceTasks, minSupport); + + // tally collocations and perform LLR calculation + computeNGramsPruneByLLR(output, baseConf, ngramCount, true, minLLRValue, reduceTasks); + } + + /** + * pass1: generate collocations, ngrams + */ + private static long generateCollocations(Path input, + Path output, + Configuration baseConf, + boolean emitUnigrams, + int maxNGramSize, + int reduceTasks, + int minSupport) + throws IOException, ClassNotFoundException, InterruptedException { + + Configuration con = new Configuration(baseConf); + con.setBoolean(EMIT_UNIGRAMS, emitUnigrams); + con.setInt(CollocMapper.MAX_SHINGLE_SIZE, maxNGramSize); + con.setInt(CollocReducer.MIN_SUPPORT, minSupport); + + Job job = new Job(con); + job.setJobName(CollocDriver.class.getSimpleName() + ".generateCollocations:" + input); + job.setJarByClass(CollocDriver.class); + + job.setMapOutputKeyClass(GramKey.class); + job.setMapOutputValueClass(Gram.class); + job.setPartitionerClass(GramKeyPartitioner.class); + job.setGroupingComparatorClass(GramKeyGroupComparator.class); + + job.setOutputKeyClass(Gram.class); + job.setOutputValueClass(Gram.class); + + job.setCombinerClass(CollocCombiner.class); + + FileInputFormat.setInputPaths(job, input); + + Path outputPath = new Path(output, SUBGRAM_OUTPUT_DIRECTORY); + FileOutputFormat.setOutputPath(job, outputPath); + + job.setInputFormatClass(SequenceFileInputFormat.class); + job.setMapperClass(CollocMapper.class); + + job.setOutputFormatClass(SequenceFileOutputFormat.class); + job.setReducerClass(CollocReducer.class); + job.setNumReduceTasks(reduceTasks); + + boolean succeeded = job.waitForCompletion(true); + if (!succeeded) { + throw new IllegalStateException("Job failed!"); + } + + return job.getCounters().findCounter(CollocMapper.Count.NGRAM_TOTAL).getValue(); + } + + /** + * pass2: perform the LLR calculation + */ + private static void computeNGramsPruneByLLR(Path output, + Configuration baseConf, + long nGramTotal, + boolean emitUnigrams, + float minLLRValue, + int reduceTasks) + throws IOException, InterruptedException, ClassNotFoundException { + Configuration conf = new Configuration(baseConf); + conf.setLong(LLRReducer.NGRAM_TOTAL, nGramTotal); + conf.setBoolean(EMIT_UNIGRAMS, emitUnigrams); + conf.setFloat(LLRReducer.MIN_LLR, minLLRValue); + + Job job = new Job(conf); + job.setJobName(CollocDriver.class.getSimpleName() + ".computeNGrams: " + output); + job.setJarByClass(CollocDriver.class); + + job.setMapOutputKeyClass(Gram.class); + job.setMapOutputValueClass(Gram.class); + + job.setOutputKeyClass(Text.class); + job.setOutputValueClass(DoubleWritable.class); + + FileInputFormat.setInputPaths(job, new Path(output, SUBGRAM_OUTPUT_DIRECTORY)); + Path outPath = new Path(output, NGRAM_OUTPUT_DIRECTORY); + FileOutputFormat.setOutputPath(job, outPath); + + job.setMapperClass(Mapper.class); + job.setInputFormatClass(SequenceFileInputFormat.class); + job.setOutputFormatClass(SequenceFileOutputFormat.class); + 
job.setReducerClass(LLRReducer.class); + job.setNumReduceTasks(reduceTasks); + + boolean succeeded = job.waitForCompletion(true); + if (!succeeded) { + throw new IllegalStateException("Job failed!"); + } + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/collocations/llr/CollocMapper.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/collocations/llr/CollocMapper.java b/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/collocations/llr/CollocMapper.java new file mode 100644 index 0000000..fd99293 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/collocations/llr/CollocMapper.java @@ -0,0 +1,178 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.vectorizer.collocations.llr; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.io.Text; +import org.apache.hadoop.mapreduce.Mapper; +import org.apache.lucene.analysis.shingle.ShingleFilter; +import org.apache.lucene.analysis.tokenattributes.CharTermAttribute; +import org.apache.lucene.analysis.tokenattributes.TypeAttribute; +import org.apache.mahout.common.StringTuple; +import org.apache.mahout.common.lucene.IteratorTokenStream; +import org.apache.mahout.math.function.ObjectIntProcedure; +import org.apache.mahout.math.map.OpenObjectIntHashMap; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; + +/** + * Pass 1 of the Collocation discovery job which generated ngrams and emits ngrams an their component n-1grams. + * Input is a SequeceFile<Text,StringTuple>, where the key is a document id and the value is the tokenized documents. + * <p/> + */ +public class CollocMapper extends Mapper<Text, StringTuple, GramKey, Gram> { + + private static final byte[] EMPTY = new byte[0]; + + public static final String MAX_SHINGLE_SIZE = "maxShingleSize"; + + private static final int DEFAULT_MAX_SHINGLE_SIZE = 2; + + public enum Count { + NGRAM_TOTAL + } + + private static final Logger log = LoggerFactory.getLogger(CollocMapper.class); + + private int maxShingleSize; + + private boolean emitUnigrams; + + /** + * Collocation finder: pass 1 map phase. + * <p/> + * Receives a token stream which gets passed through a Lucene ShingleFilter. 
The ShingleFilter delivers ngrams of + * the appropriate size which are then decomposed into head and tail subgrams which are collected in the + * following manner + * <p/> + * <pre> + * k:head_key, v:head_subgram + * k:head_key,ngram_key, v:ngram + * k:tail_key, v:tail_subgram + * k:tail_key,ngram_key, v:ngram + * </pre> + * <p/> + * The 'head' or 'tail' prefix is used to specify whether the subgram in question is the head or tail of the + * ngram. In this implementation the head of the ngram is a (n-1)gram, and the tail is a (1)gram. + * <p/> + * For example, given 'click and clack' and an ngram length of 3: + * <pre> + * k: head_'click and' v:head_'click and' + * k: head_'click and',ngram_'click and clack' v:ngram_'click and clack' + * k: tail_'clack', v:tail_'clack' + * k: tail_'clack',ngram_'click and clack' v:ngram_'click and clack' + * </pre> + * <p/> + * Also counts the total number of ngrams encountered and adds it to the counter + * CollocDriver.Count.NGRAM_TOTAL + * </p> + * + * @throws IOException if there's a problem with the ShingleFilter reading data or the collector collecting output. + */ + @Override + protected void map(Text key, StringTuple value, final Context context) throws IOException, InterruptedException { + + try (ShingleFilter sf = new ShingleFilter(new IteratorTokenStream(value.getEntries().iterator()), maxShingleSize)){ + sf.reset(); + int count = 0; // ngram count + + OpenObjectIntHashMap<String> ngrams = + new OpenObjectIntHashMap<>(value.getEntries().size() * (maxShingleSize - 1)); + OpenObjectIntHashMap<String> unigrams = new OpenObjectIntHashMap<>(value.getEntries().size()); + + do { + String term = sf.getAttribute(CharTermAttribute.class).toString(); + String type = sf.getAttribute(TypeAttribute.class).type(); + if ("shingle".equals(type)) { + count++; + ngrams.adjustOrPutValue(term, 1, 1); + } else if (emitUnigrams && !term.isEmpty()) { // unigram + unigrams.adjustOrPutValue(term, 1, 1); + } + } while (sf.incrementToken()); + + final GramKey gramKey = new GramKey(); + + ngrams.forEachPair(new ObjectIntProcedure<String>() { + @Override + public boolean apply(String term, int frequency) { + // obtain components, the leading (n-1)gram and the trailing unigram. + int i = term.lastIndexOf(' '); // TODO: fix for non-whitespace delimited languages. 
+ if (i != -1) { // bigram, trigram etc + + try { + Gram ngram = new Gram(term, frequency, Gram.Type.NGRAM); + Gram head = new Gram(term.substring(0, i), frequency, Gram.Type.HEAD); + Gram tail = new Gram(term.substring(i + 1), frequency, Gram.Type.TAIL); + + gramKey.set(head, EMPTY); + context.write(gramKey, head); + + gramKey.set(head, ngram.getBytes()); + context.write(gramKey, ngram); + + gramKey.set(tail, EMPTY); + context.write(gramKey, tail); + + gramKey.set(tail, ngram.getBytes()); + context.write(gramKey, ngram); + + } catch (IOException | InterruptedException e) { + throw new IllegalStateException(e); + } + } + return true; + } + }); + + unigrams.forEachPair(new ObjectIntProcedure<String>() { + @Override + public boolean apply(String term, int frequency) { + try { + Gram unigram = new Gram(term, frequency, Gram.Type.UNIGRAM); + gramKey.set(unigram, EMPTY); + context.write(gramKey, unigram); + } catch (IOException | InterruptedException e) { + throw new IllegalStateException(e); + } + return true; + } + }); + + context.getCounter(Count.NGRAM_TOTAL).increment(count); + sf.end(); + } + } + + @Override + protected void setup(Context context) throws IOException, InterruptedException { + super.setup(context); + Configuration conf = context.getConfiguration(); + this.maxShingleSize = conf.getInt(MAX_SHINGLE_SIZE, DEFAULT_MAX_SHINGLE_SIZE); + + this.emitUnigrams = conf.getBoolean(CollocDriver.EMIT_UNIGRAMS, CollocDriver.DEFAULT_EMIT_UNIGRAMS); + + if (log.isInfoEnabled()) { + log.info("Max Ngram size is {}", this.maxShingleSize); + log.info("Emit Unitgrams is {}", emitUnigrams); + } + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/collocations/llr/CollocReducer.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/collocations/llr/CollocReducer.java b/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/collocations/llr/CollocReducer.java new file mode 100644 index 0000000..1fe13e3 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/collocations/llr/CollocReducer.java @@ -0,0 +1,176 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.vectorizer.collocations.llr; + +import java.io.IOException; +import java.util.Iterator; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.mapreduce.Reducer; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Reducer for Pass 1 of the collocation identification job. Generates counts for ngrams and subgrams. 
+ */ +public class CollocReducer extends Reducer<GramKey, Gram, Gram, Gram> { + + private static final Logger log = LoggerFactory.getLogger(CollocReducer.class); + + public static final String MIN_SUPPORT = "minSupport"; + + public static final int DEFAULT_MIN_SUPPORT = 2; + + public enum Skipped { + LESS_THAN_MIN_SUPPORT, MALFORMED_KEY_TUPLE, MALFORMED_TUPLE, MALFORMED_TYPES, MALFORMED_UNIGRAM + } + + private int minSupport; + + /** + * collocation finder: pass 1 reduce phase: + * <p/> + * given input from the mapper, + * + * <pre> + * k:head_subgram,ngram, v:ngram:partial freq + * k:head_subgram v:head_subgram:partial freq + * k:tail_subgram,ngram, v:ngram:partial freq + * k:tail_subgram v:tail_subgram:partial freq + * k:unigram v:unigram:partial freq + * </pre> + * sum gram frequencies and output for llr calculation + * <p/> + * output is: + * <pre> + * k:ngram:ngramfreq v:head_subgram:head_subgramfreq + * k:ngram:ngramfreq v:tail_subgram:tail_subgramfreq + * k:unigram:unigramfreq v:unigram:unigramfreq + * </pre> + * Each ngram's frequency is essentially counted twice, once for head, once for tail. + * frequency should be the same for the head and tail. Fix this to count only for the + * head and move the count into the value? + */ + @Override + protected void reduce(GramKey key, Iterable<Gram> values, Context context) throws IOException, InterruptedException { + + Gram.Type keyType = key.getType(); + + if (keyType == Gram.Type.UNIGRAM) { + // sum frequencies for unigrams. + processUnigram(values.iterator(), context); + } else if (keyType == Gram.Type.HEAD || keyType == Gram.Type.TAIL) { + // sum frequencies for subgrams, ngram and collect for each ngram. + processSubgram(values.iterator(), context); + } else { + context.getCounter(Skipped.MALFORMED_TYPES).increment(1); + } + } + + @Override + protected void setup(Context context) throws IOException, InterruptedException { + super.setup(context); + Configuration conf = context.getConfiguration(); + this.minSupport = conf.getInt(MIN_SUPPORT, DEFAULT_MIN_SUPPORT); + + boolean emitUnigrams = conf.getBoolean(CollocDriver.EMIT_UNIGRAMS, CollocDriver.DEFAULT_EMIT_UNIGRAMS); + + log.info("Min support is {}", minSupport); + log.info("Emit Unitgrams is {}", emitUnigrams); + } + + /** + * Sum frequencies for unigrams and deliver to the collector + */ + protected void processUnigram(Iterator<Gram> values, Context context) + throws IOException, InterruptedException { + + int freq = 0; + Gram value = null; + + // accumulate frequencies from values. + while (values.hasNext()) { + value = values.next(); + freq += value.getFrequency(); + } + + if (freq < minSupport) { + context.getCounter(Skipped.LESS_THAN_MIN_SUPPORT).increment(1); + return; + } + + value.setFrequency(freq); + context.write(value, value); + + } + + /** Sum frequencies for subgram, ngrams and deliver ngram, subgram pairs to the collector. + * <p/> + * Sort order guarantees that the subgram/subgram pairs will be seen first and then + * subgram/ngram1 pairs, subgram/ngram2 pairs ... subgram/ngramN pairs, so frequencies for + * ngrams can be calcualted here as well. + * <p/> + * We end up calculating frequencies for ngrams for each sugram (head, tail) here, which is + * some extra work. 
+ * @throws InterruptedException + */ + protected void processSubgram(Iterator<Gram> values, Context context) + throws IOException, InterruptedException { + + Gram subgram = null; + Gram currentNgram = null; + + while (values.hasNext()) { + Gram value = values.next(); + + if (value.getType() == Gram.Type.HEAD || value.getType() == Gram.Type.TAIL) { + // collect frequency for subgrams. + if (subgram == null) { + subgram = new Gram(value); + } else { + subgram.incrementFrequency(value.getFrequency()); + } + } else if (!value.equals(currentNgram)) { + // we've collected frequency for all subgrams and we've encountered a new ngram. + // collect the old ngram if there was one and we have sufficient support and + // create the new ngram. + if (currentNgram != null) { + if (currentNgram.getFrequency() < minSupport) { + context.getCounter(Skipped.LESS_THAN_MIN_SUPPORT).increment(1); + } else { + context.write(currentNgram, subgram); + } + } + + currentNgram = new Gram(value); + } else { + currentNgram.incrementFrequency(value.getFrequency()); + } + } + + // collect last ngram. + if (currentNgram != null) { + if (currentNgram.getFrequency() < minSupport) { + context.getCounter(Skipped.LESS_THAN_MIN_SUPPORT).increment(1); + return; + } + + context.write(currentNgram, subgram); + } + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/collocations/llr/Gram.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/collocations/llr/Gram.java b/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/collocations/llr/Gram.java new file mode 100644 index 0000000..58234b3 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/collocations/llr/Gram.java @@ -0,0 +1,239 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.vectorizer.collocations.llr; + +import java.io.DataInput; +import java.io.DataOutput; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.charset.CharacterCodingException; + +import com.google.common.base.Preconditions; +import org.apache.hadoop.io.BinaryComparable; +import org.apache.hadoop.io.Text; +import org.apache.hadoop.io.WritableComparable; +import org.apache.mahout.math.Varint; + +/** + * Writable for holding data generated from the collocation discovery jobs. Depending on the job configuration + * gram may be one or more words. In some contexts this is used to hold a complete ngram, while in others it + * holds a part of an existing ngram (subgram). Tracks the frequency of the gram and its position in the ngram + * in which is was found. 
+ */ +public class Gram extends BinaryComparable implements WritableComparable<BinaryComparable> { + + public enum Type { + HEAD('h'), + TAIL('t'), + UNIGRAM('u'), + NGRAM('n'); + + private final char x; + + Type(char c) { + this.x = c; + } + + @Override + public String toString() { + return String.valueOf(x); + } + } + + private byte[] bytes; + private int length; + private int frequency; + + public Gram() { + + } + + /** + * Copy constructor + */ + public Gram(Gram other) { + frequency = other.frequency; + length = other.length; + bytes = other.bytes.clone(); + } + + /** + * Create an gram with a frequency of 1 + * + * @param ngram + * the gram string + * @param type + * whether the gram is at the head or tail of its text unit or it is a unigram + */ + public Gram(String ngram, Type type) { + this(ngram, 1, type); + } + + + /** + * + * Create a gram with the specified frequency. + * + * @param ngram + * the gram string + * @param frequency + * the gram frequency + * @param type + * whether the gram is at the head of its text unit or tail or unigram + */ + public Gram(String ngram, int frequency, Type type) { + Preconditions.checkNotNull(ngram); + try { + // extra character is used for storing type which is part + // of the sort key. + ByteBuffer bb = Text.encode('\0' + ngram, true); + bytes = bb.array(); + length = bb.limit(); + } catch (CharacterCodingException e) { + throw new IllegalStateException("Should not have happened ",e); + } + + encodeType(type, bytes, 0); + this.frequency = frequency; + } + + + @Override + public byte[] getBytes() { + return bytes; + } + + @Override + public int getLength() { + return length; + } + + /** + * @return the gram is at the head of its text unit or tail or unigram. + */ + public Type getType() { + return decodeType(bytes, 0); + } + + /** + * @return gram term string + */ + public String getString() { + try { + return Text.decode(bytes, 1, length - 1); + } catch (CharacterCodingException e) { + throw new IllegalStateException("Should not have happened " + e); + } + } + + /** + * @return gram frequency + */ + public int getFrequency() { + return frequency; + } + + /** + * @param frequency + * gram's frequency + */ + public void setFrequency(int frequency) { + this.frequency = frequency; + } + + public void incrementFrequency(int i) { + this.frequency += i; + } + + @Override + public void readFields(DataInput in) throws IOException { + int newLength = Varint.readUnsignedVarInt(in); + setCapacity(newLength, false); + in.readFully(bytes, 0, newLength); + int newFrequency = Varint.readUnsignedVarInt(in); + length = newLength; + frequency = newFrequency; + } + + @Override + public void write(DataOutput out) throws IOException { + Varint.writeUnsignedVarInt(length, out); + out.write(bytes, 0, length); + Varint.writeUnsignedVarInt(frequency, out); + } + + /* Cribbed from o.a.hadoop.io.Text: + * Sets the capacity of this object to <em>at least</em> + * {@code len} bytes. If the current buffer is longer, + * then the capacity and existing content of the buffer are + * unchanged. If {@code len} is larger + * than the current capacity, this object's capacity is + * increased to match. 
+   * @param len the number of bytes we need
+   * @param keepData should the old data be kept
+   */
+  private void setCapacity(int len, boolean keepData) {
+    len++; // extra byte to hold type
+    if (bytes == null || bytes.length < len) {
+      byte[] newBytes = new byte[len];
+      if (bytes != null && keepData) {
+        System.arraycopy(bytes, 0, newBytes, 0, length);
+      }
+      bytes = newBytes;
+    }
+  }
+
+  @Override
+  public String toString() {
+    return '\'' + getString() + "'[" + getType() + "]:" + frequency;
+  }
+
+  public static void encodeType(Type type, byte[] buf, int offset) {
+    switch (type) {
+      case HEAD:
+        buf[offset] = 0x1;
+        break;
+      case TAIL:
+        buf[offset] = 0x2;
+        break;
+      case UNIGRAM:
+        buf[offset] = 0x3;
+        break;
+      case NGRAM:
+        buf[offset] = 0x4;
+        break;
+      default:
+        throw new IllegalStateException("switch/case problem in encodeType");
+    }
+  }
+
+  public static Type decodeType(byte[] buf, int offset) {
+    switch (buf[offset]) {
+      case 0x1:
+        return Type.HEAD;
+      case 0x2:
+        return Type.TAIL;
+      case 0x3:
+        return Type.UNIGRAM;
+      case 0x4:
+        return Type.NGRAM;
+      default:
+        throw new IllegalStateException("switch/case problem in decodeType");
+    }
+  }
+}
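A quick round trip makes the wire format above concrete: a Gram serializes as varint(length), then length bytes (the one-byte type marker written by encodeType at offset 0, followed by the UTF-8 term), then varint(frequency). The class name GramRoundTrip and the sample term are invented for the example; only the Gram API above and JDK stream classes are assumed.

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;

import org.apache.mahout.vectorizer.collocations.llr.Gram;

public class GramRoundTrip {
  public static void main(String[] args) throws IOException {
    Gram original = new Gram("machine learning", 7, Gram.Type.NGRAM);

    ByteArrayOutputStream baos = new ByteArrayOutputStream();
    original.write(new DataOutputStream(baos));  // varint length, type byte + UTF-8 term, varint frequency

    Gram copy = new Gram();
    copy.readFields(new DataInputStream(new ByteArrayInputStream(baos.toByteArray())));

    System.out.println(copy);                // 'machine learning'[n]:7
    System.out.println(copy.getFrequency()); // 7
  }
}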
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/collocations/llr/GramKey.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/collocations/llr/GramKey.java b/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/collocations/llr/GramKey.java
new file mode 100644
index 0000000..e33ed51
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/collocations/llr/GramKey.java
@@ -0,0 +1,133 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.vectorizer.collocations.llr;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+import java.nio.charset.CharacterCodingException;
+
+import org.apache.hadoop.io.BinaryComparable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.io.WritableComparable;
+import org.apache.mahout.math.Varint;
+import org.apache.mahout.vectorizer.collocations.llr.Gram.Type;
+
+/** A GramKey, based on the identity fields of Gram (type, string) plus a byte[] used for secondary ordering */
+public final class GramKey extends BinaryComparable implements WritableComparable<BinaryComparable> {
+
+  private int primaryLength;
+  private int length;
+  private byte[] bytes;
+
+  public GramKey() {
+
+  }
+
+  /** create a GramKey based on the specified Gram and order
+   *
+   * @param gram
+   * @param order
+   */
+  public GramKey(Gram gram, byte[] order) {
+    set(gram, order);
+  }
+
+  /** set the gram held by this key */
+  public void set(Gram gram, byte[] order) {
+    primaryLength = gram.getLength();
+    length = primaryLength + order.length;
+    setCapacity(length, false);
+    System.arraycopy(gram.getBytes(), 0, bytes, 0, primaryLength);
+    if (order.length > 0) {
+      System.arraycopy(order, 0, bytes, primaryLength, order.length);
+    }
+  }
+
+  @Override
+  public byte[] getBytes() {
+    return bytes;
+  }
+
+  @Override
+  public int getLength() {
+    return length;
+  }
+
+  public int getPrimaryLength() {
+    return primaryLength;
+  }
+
+  @Override
+  public void readFields(DataInput in) throws IOException {
+    int newLength = Varint.readUnsignedVarInt(in);
+    int newPrimaryLength = Varint.readUnsignedVarInt(in);
+    setCapacity(newLength, false);
+    in.readFully(bytes, 0, newLength);
+    length = newLength;
+    primaryLength = newPrimaryLength;
+  }
+
+  @Override
+  public void write(DataOutput out) throws IOException {
+    Varint.writeUnsignedVarInt(length, out);
+    Varint.writeUnsignedVarInt(primaryLength, out);
+    out.write(bytes, 0, length);
+  }
+
+  /* Cribbed from o.a.hadoop.io.Text:
+   * Sets the capacity of this object to <em>at least</em>
+   * {@code len} bytes. If the current buffer is longer,
+   * then the capacity and existing content of the buffer are
+   * unchanged. If {@code len} is larger
+   * than the current capacity, this object's capacity is
+   * increased to match.
+   * @param len the number of bytes we need
+   * @param keepData should the old data be kept
+   */
+  private void setCapacity(int len, boolean keepData) {
+    if (bytes == null || bytes.length < len) {
+      byte[] newBytes = new byte[len];
+      if (bytes != null && keepData) {
+        System.arraycopy(bytes, 0, newBytes, 0, length);
+      }
+      bytes = newBytes;
+    }
+  }
+
+  /**
+   * @return whether the gram is at the head or tail of its text unit, or is a unigram.
+   */
+  public Type getType() {
+    return Gram.decodeType(bytes, 0);
+  }
+
+  public String getPrimaryString() {
+    try {
+      return Text.decode(bytes, 1, primaryLength - 1);
+    } catch (CharacterCodingException e) {
+      throw new IllegalStateException(e);
+    }
+  }
+
+  @Override
+  public String toString() {
+    return '\'' + getPrimaryString() + "'[" + getType() + ']';
+  }
+}

http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/collocations/llr/GramKeyGroupComparator.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/collocations/llr/GramKeyGroupComparator.java b/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/collocations/llr/GramKeyGroupComparator.java
new file mode 100644
index 0000000..7b73d71
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/collocations/llr/GramKeyGroupComparator.java
@@ -0,0 +1,43 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.vectorizer.collocations.llr;
+
+import org.apache.hadoop.io.WritableComparable;
+import org.apache.hadoop.io.WritableComparator;
+
+import java.io.Serializable;
+
+/** Group GramKeys based on their Gram, ignoring the secondary sort key, so that all keys with the same Gram are sent
+ * to the same call of the reduce method, sorted in natural order (for GramKeys).
+ */
+class GramKeyGroupComparator extends WritableComparator implements Serializable {
+
+  GramKeyGroupComparator() {
+    super(GramKey.class, true);
+  }
+
+  @Override
+  public int compare(WritableComparable a, WritableComparable b) {
+    GramKey gka = (GramKey) a;
+    GramKey gkb = (GramKey) b;
+
+    return WritableComparator.compareBytes(gka.getBytes(), 0, gka.getPrimaryLength(),
+                                           gkb.getBytes(), 0, gkb.getPrimaryLength());
+  }
+
+}
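To make the key layout concrete: a GramKey is the wrapped Gram's bytes followed by caller-supplied order bytes, and the group comparator above compares only the primary (Gram) portion, so keys that differ only in their order suffix land in the same reduce group. A small illustrative sketch of the layout; the class name and values are invented for the example.

import org.apache.mahout.vectorizer.collocations.llr.Gram;
import org.apache.mahout.vectorizer.collocations.llr.GramKey;

public class GramKeySketch {
  public static void main(String[] args) {
    Gram head = new Gram("machine", 1, Gram.Type.HEAD);
    GramKey key = new GramKey(head, new byte[] {0x1});  // one secondary-order byte appended

    System.out.println(key.getPrimaryString());                    // machine
    System.out.println(key.getLength() - key.getPrimaryLength());  // 1, the order suffix length
    System.out.println(key);                                       // 'machine'[h]
  }
}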
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/collocations/llr/GramKeyPartitioner.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/collocations/llr/GramKeyPartitioner.java b/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/collocations/llr/GramKeyPartitioner.java
new file mode 100644
index 0000000..a68339f
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/collocations/llr/GramKeyPartitioner.java
@@ -0,0 +1,40 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.vectorizer.collocations.llr;
+
+import org.apache.hadoop.mapreduce.Partitioner;
+
+/**
+ * Partition GramKeys based on their Gram, ignoring the secondary sort key so that all GramKeys with the same
+ * gram are sent to the same partition.
+ */
+public final class GramKeyPartitioner extends Partitioner<GramKey, Gram> {
+
+  @Override
+  public int getPartition(GramKey key, Gram value, int numPartitions) {
+    int hash = 1;
+    byte[] bytes = key.getBytes();
+    int length = key.getPrimaryLength();
+    // Copied from WritableComparator.hashBytes(); skips first byte, type byte
+    for (int i = 1; i < length; i++) {
+      hash = (31 * hash) + bytes[i];
+    }
+    return (hash & Integer.MAX_VALUE) % numPartitions;
+  }
+
+}
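The partitioner hashes only the primary bytes (skipping the type byte), so keys built from the same Gram always land in the same partition regardless of their secondary order bytes. An illustrative check; the class name and values are invented, and null is passed for the Gram value because getPartition does not use it.

import org.apache.mahout.vectorizer.collocations.llr.Gram;
import org.apache.mahout.vectorizer.collocations.llr.GramKey;
import org.apache.mahout.vectorizer.collocations.llr.GramKeyPartitioner;

public class PartitionSketch {
  public static void main(String[] args) {
    Gram head = new Gram("machine", 1, Gram.Type.HEAD);
    GramKeyPartitioner partitioner = new GramKeyPartitioner();

    int a = partitioner.getPartition(new GramKey(head, new byte[0]), null, 8);
    int b = partitioner.getPartition(new GramKey(head, new byte[] {0x7f}), null, 8);
    System.out.println(a == b);  // true: the order bytes never influence the partition
  }
}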
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/collocations/llr/LLRReducer.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/collocations/llr/LLRReducer.java b/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/collocations/llr/LLRReducer.java
new file mode 100644
index 0000000..d414416
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/collocations/llr/LLRReducer.java
@@ -0,0 +1,170 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.vectorizer.collocations.llr;
+
+import java.io.IOException;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.io.DoubleWritable;
+import org.apache.hadoop.io.Text;
+import org.apache.hadoop.mapreduce.Reducer;
+import org.apache.mahout.math.stats.LogLikelihood;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Reducer for pass 2 of the collocation discovery job. Collects ngram and sub-ngram frequencies and performs
+ * the Log-likelihood ratio calculation.
+ */
+public class LLRReducer extends Reducer<Gram, Gram, Text, DoubleWritable> {
+
+  /** Counter to track why a particular entry was skipped */
+  public enum Skipped {
+    EXTRA_HEAD, EXTRA_TAIL, MISSING_HEAD, MISSING_TAIL, LESS_THAN_MIN_LLR, LLR_CALCULATION_ERROR,
+  }
+
+  private static final Logger log = LoggerFactory.getLogger(LLRReducer.class);
+
+  public static final String NGRAM_TOTAL = "ngramTotal";
+  public static final String MIN_LLR = "minLLR";
+  public static final float DEFAULT_MIN_LLR = 1.0f;
+
+  private long ngramTotal;
+  private float minLLRValue;
+  private boolean emitUnigrams;
+  private final LLCallback ll;
+
+  /**
+   * Perform the LLR calculation. Input is: k:ngram:ngramFreq v:(h_|t_)subgram:subgramFreq, N = ngram total.
+   *
+   * Each ngram will have 2 subgrams, a head and a tail, referred to as A and B respectively below.
+   *
+   *  A+ B: number of times A and B appear together: ngramFreq
+   *  A+!B: number of times A appears without B: hSubgramFreq - ngramFreq
+   * !A+ B: number of times B appears without A: tSubgramFreq - ngramFreq
+   * !A+!B: number of times neither A nor B appears (in that order): N - (subgramFreqA + subgramFreqB - ngramFreq)
+   */
+  @Override
+  protected void reduce(Gram ngram, Iterable<Gram> values, Context context) throws IOException, InterruptedException {
+
+    int[] gramFreq = {-1, -1};
+
+    if (ngram.getType() == Gram.Type.UNIGRAM && emitUnigrams) {
+      DoubleWritable dd = new DoubleWritable(ngram.getFrequency());
+      Text t = new Text(ngram.getString());
+      context.write(t, dd);
+      return;
+    }
+    // TODO better way to handle errors? Wouldn't an exception thrown here
+    // cause hadoop to re-try the job?
+    String[] gram = new String[2];
+    for (Gram value : values) {
+
+      int pos = value.getType() == Gram.Type.HEAD ? 0 : 1;
+
+      if (gramFreq[pos] != -1) {
+        log.warn("Extra {} for {}, skipping", value.getType(), ngram);
+        if (value.getType() == Gram.Type.HEAD) {
+          context.getCounter(Skipped.EXTRA_HEAD).increment(1);
+        } else {
+          context.getCounter(Skipped.EXTRA_TAIL).increment(1);
+        }
+        return;
+      }
+
+      gram[pos] = value.getString();
+      gramFreq[pos] = value.getFrequency();
+    }
+
+    if (gramFreq[0] == -1) {
+      log.warn("Missing head for {}, skipping.", ngram);
+      context.getCounter(Skipped.MISSING_HEAD).increment(1);
+      return;
+    }
+    if (gramFreq[1] == -1) {
+      log.warn("Missing tail for {}, skipping", ngram);
+      context.getCounter(Skipped.MISSING_TAIL).increment(1);
+      return;
+    }

+    long k11 = ngram.getFrequency();                                             /* a&b */
+    long k12 = gramFreq[0] - ngram.getFrequency();                               /* a&!b */
+    long k21 = gramFreq[1] - ngram.getFrequency();                               /* !a&b */
+    long k22 = ngramTotal - (gramFreq[0] + gramFreq[1] - ngram.getFrequency());  /* !a&!b */
+
+    double llr;
+    try {
+      llr = ll.logLikelihoodRatio(k11, k12, k21, k22);
+    } catch (IllegalArgumentException ex) {
+      context.getCounter(Skipped.LLR_CALCULATION_ERROR).increment(1);
+      log.warn("Problem calculating LLR ratio for ngram {}, HEAD {}:{}, TAIL {}:{}, k11/k12/k21/k22: {}/{}/{}/{}",
+          ngram, gram[0], gramFreq[0], gram[1], gramFreq[1], k11, k12, k21, k22, ex);
+      return;
+    }
+    if (llr < minLLRValue) {
+      context.getCounter(Skipped.LESS_THAN_MIN_LLR).increment(1);
+    } else {
+      context.write(new Text(ngram.getString()), new DoubleWritable(llr));
+    }
+  }
+
+  @Override
+  protected void setup(Context context) throws IOException, InterruptedException {
+    super.setup(context);
+    Configuration conf = context.getConfiguration();
+    this.ngramTotal = conf.getLong(NGRAM_TOTAL, -1);
+    this.minLLRValue = conf.getFloat(MIN_LLR, DEFAULT_MIN_LLR);
+
+    this.emitUnigrams = conf.getBoolean(CollocDriver.EMIT_UNIGRAMS, CollocDriver.DEFAULT_EMIT_UNIGRAMS);
+
+    log.info("NGram Total: {}, Min LLR value: {}, Emit Unigrams: {}",
+        ngramTotal, minLLRValue, emitUnigrams);
+
+    if (ngramTotal == -1) {
+      throw new IllegalStateException("No NGRAM_TOTAL available in job config");
+    }
+  }
+
+  public LLRReducer() {
+    this.ll = new ConcreteLLCallback();
+  }
+
+  /**
+   * plug in an alternate LL implementation, used for testing
+   *
+   * @param ll
+   *          the LL to use.
+   */
+  LLRReducer(LLCallback ll) {
+    this.ll = ll;
+  }
+
+  /**
+   * provide interface so the input to the llr calculation can be captured for validation in unit testing
+   */
+  public interface LLCallback {
+    double logLikelihoodRatio(long k11, long k12, long k21, long k22);
+  }
+
+  /** concrete implementation delegates to LogLikelihood class */
+  public static final class ConcreteLLCallback implements LLCallback {
+    @Override
+    public double logLikelihoodRatio(long k11, long k12, long k21, long k22) {
+      return LogLikelihood.logLikelihoodRatio(k11, k12, k21, k22);
+    }
+  }
+}
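The contingency table assembled in reduce() is easier to see with numbers. The sketch below builds the four counts for a hypothetical bigram and calls the same LogLikelihood.logLikelihoodRatio used by ConcreteLLCallback; the class name and all frequencies are invented for illustration.

import org.apache.mahout.math.stats.LogLikelihood;

public class LlrSketch {
  public static void main(String[] args) {
    long ngramFreq = 40;      // times head and tail occur together, e.g. "machine learning"
    long headFreq = 100;      // occurrences of the head subgram, "machine _"
    long tailFreq = 150;      // occurrences of the tail subgram, "_ learning"
    long ngramTotal = 10000;  // N, handed to the reducer via the NGRAM_TOTAL setting

    long k11 = ngramFreq;                                       // A and B together
    long k12 = headFreq - ngramFreq;                            // A without B
    long k21 = tailFreq - ngramFreq;                            // B without A
    long k22 = ngramTotal - (headFreq + tailFreq - ngramFreq);  // neither A nor B

    double llr = LogLikelihood.logLikelihoodRatio(k11, k12, k21, k22);
    System.out.println(llr);  // written out only if it reaches minLLR (default 1.0)
  }
}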
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/common/PartialVectorMergeReducer.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/common/PartialVectorMergeReducer.java b/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/common/PartialVectorMergeReducer.java
new file mode 100644
index 0000000..a8eacc3
--- /dev/null
+++ b/community/mahout-mr/src/main/java/org/apache/mahout/vectorizer/common/PartialVectorMergeReducer.java
@@ -0,0 +1,89 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.vectorizer.common;
+
+import java.io.IOException;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.io.WritableComparable;
+import org.apache.hadoop.mapreduce.Reducer;
+import org.apache.mahout.math.NamedVector;
+import org.apache.mahout.math.RandomAccessSparseVector;
+import org.apache.mahout.math.SequentialAccessSparseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.function.Functions;
+
+/**
+ * Merges partial vectors into a full sparse vector
+ */
+public class PartialVectorMergeReducer extends
+    Reducer<WritableComparable<?>, VectorWritable, WritableComparable<?>, VectorWritable> {
+
+  private double normPower;
+
+  private int dimension;
+
+  private boolean sequentialAccess;
+
+  private boolean namedVector;
+
+  private boolean logNormalize;
+
+  @Override
+  protected void reduce(WritableComparable<?> key, Iterable<VectorWritable> values, Context context)
+      throws IOException, InterruptedException {
+
+    Vector vector = new RandomAccessSparseVector(dimension, 10);
+    for (VectorWritable value : values) {
+      vector.assign(value.get(), Functions.PLUS);
+    }
+    if (normPower != PartialVectorMerger.NO_NORMALIZING) {
+      if (logNormalize) {
+        vector = vector.logNormalize(normPower);
+      } else {
+        vector = vector.normalize(normPower);
+      }
+    }
+    if (sequentialAccess) {
+      vector = new SequentialAccessSparseVector(vector);
+    }
+
+    if (namedVector) {
+      vector = new NamedVector(vector, key.toString());
+    }
+
+    // drop empty vectors.
+    if (vector.getNumNondefaultElements() > 0) {
+      VectorWritable vectorWritable = new VectorWritable(vector);
+      context.write(key, vectorWritable);
+    }
+  }
+
+  @Override
+  protected void setup(Context context) throws IOException, InterruptedException {
+    super.setup(context);
+    Configuration conf = context.getConfiguration();
+    normPower = conf.getFloat(PartialVectorMerger.NORMALIZATION_POWER, PartialVectorMerger.NO_NORMALIZING);
+    dimension = conf.getInt(PartialVectorMerger.DIMENSION, Integer.MAX_VALUE);
+    sequentialAccess = conf.getBoolean(PartialVectorMerger.SEQUENTIAL_ACCESS, false);
+    namedVector = conf.getBoolean(PartialVectorMerger.NAMED_VECTOR, false);
+    logNormalize = conf.getBoolean(PartialVectorMerger.LOG_NORMALIZE, false);
+  }
+
+}
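For reference, the heart of the reduce step above replayed on two hand-built partial vectors outside Hadoop: the partials are summed element-wise and then normalized when a norm power is configured. The class name, dimension and values are invented for the example; only mahout-math classes already imported above are assumed.

import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.function.Functions;

public class MergeSketch {
  public static void main(String[] args) {
    Vector merged = new RandomAccessSparseVector(5, 10);

    Vector partial1 = new RandomAccessSparseVector(5);
    partial1.set(0, 2.0);
    partial1.set(3, 1.0);

    Vector partial2 = new RandomAccessSparseVector(5);
    partial2.set(3, 1.0);
    partial2.set(4, 4.0);

    merged.assign(partial1, Functions.PLUS);  // element-wise sum, as in the reduce loop
    merged.assign(partial2, Functions.PLUS);

    merged = merged.normalize(2);             // applied only when a norm power is configured
    System.out.println(merged.asFormatString());
  }
}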
