http://git-wip-us.apache.org/repos/asf/mahout/blob/e0573de3/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/SplitInput.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/SplitInput.java b/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/SplitInput.java new file mode 100644 index 0000000..6178f80 --- /dev/null +++ b/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/SplitInput.java @@ -0,0 +1,673 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.utils; + +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStreamReader; +import java.io.OutputStreamWriter; +import java.io.Writer; +import java.nio.charset.Charset; +import java.util.BitSet; + +import com.google.common.base.Preconditions; +import org.apache.commons.cli2.OptionException; +import org.apache.commons.io.Charsets; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FileStatus; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.SequenceFile; +import org.apache.hadoop.io.Writable; +import org.apache.hadoop.util.ToolRunner; +import org.apache.mahout.common.AbstractJob; +import org.apache.mahout.common.CommandLineUtil; +import org.apache.mahout.common.HadoopUtil; +import org.apache.mahout.common.Pair; +import org.apache.mahout.common.RandomUtils; +import org.apache.mahout.common.commandline.DefaultOptionCreator; +import org.apache.mahout.common.iterator.sequencefile.PathFilters; +import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterator; +import org.apache.mahout.math.jet.random.sampling.RandomSampler; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * A utility for splitting files in the input format used by the Bayes + * classifiers or anything else that has one item per line or SequenceFiles (key/value) + * into training and test sets in order to perform cross-validation. + * <p/> + * <p/> + * This class can be used to split directories of files or individual files into + * training and test sets using a number of different methods. + * <p/> + * When executed via {@link #splitDirectory(Path)} or {@link #splitFile(Path)}, + * the lines read from one or more, input files are written to files of the same + * name into the directories specified by the + * {@link #setTestOutputDirectory(Path)} and + * {@link #setTrainingOutputDirectory(Path)} methods. 
+ * <p/> + * The composition of the test set is determined using one of the following + * approaches: + * <ul> + * <li>A contiguous set of items can be chosen from the input file(s) using the + * {@link #setTestSplitSize(int)} or {@link #setTestSplitPct(int)} methods. + * {@link #setTestSplitSize(int)} allocates a fixed number of items, while + * {@link #setTestSplitPct(int)} allocates a percentage of the original input, + * rounded up to the nearest integer. {@link #setSplitLocation(int)} is used to + * control the position in the input from which the test data is extracted and + * is described further below.</li> + * <li>A random sampling of items can be chosen from the input files(s) using + * the {@link #setTestRandomSelectionSize(int)} or + * {@link #setTestRandomSelectionPct(int)} methods, each choosing a fixed test + * set size or percentage of the input set size as described above. The + * {@link RandomSampler} class from {@code mahout-math} is used to create a sample + * of the appropriate size.</li> + * </ul> + * <p/> + * Any one of the methods above can be used to control the size of the test set. + * If multiple methods are called, a runtime exception will be thrown at + * execution time. + * <p/> + * The {@link #setSplitLocation(int)} method is passed an integer from 0 to 100 + * (inclusive) which is translated into the position of the start of the test + * data within the input file. + * <p/> + * Given: + * <ul> + * <li>an input file of 1500 lines</li> + * <li>a desired test data size of 10 percent</li> + * </ul> + * <p/> + * <ul> + * <li>A split location of 0 will cause the first 150 items appearing in the + * input set to be written to the test set.</li> + * <li>A split location of 25 will cause items 375-525 to be written to the test + * set.</li> + * <li>A split location of 100 will cause the last 150 items in the input to be + * written to the test set</li> + * </ul> + * The start of the split will always be adjusted forwards in order to ensure + * that the desired test set size is allocated. Split location has no effect is + * random sampling is employed. + */ +public class SplitInput extends AbstractJob { + + private static final Logger log = LoggerFactory.getLogger(SplitInput.class); + + private int testSplitSize = -1; + private int testSplitPct = -1; + private int splitLocation = 100; + private int testRandomSelectionSize = -1; + private int testRandomSelectionPct = -1; + private int keepPct = 100; + private Charset charset = Charsets.UTF_8; + private boolean useSequence; + private boolean useMapRed; + + private Path inputDirectory; + private Path trainingOutputDirectory; + private Path testOutputDirectory; + private Path mapRedOutputDirectory; + + private SplitCallback callback; + + @Override + public int run(String[] args) throws Exception { + + if (parseArgs(args)) { + splitDirectory(); + } + return 0; + } + + public static void main(String[] args) throws Exception { + ToolRunner.run(new Configuration(), new SplitInput(), args); + } + + /** + * Configure this instance based on the command-line arguments contained within provided array. + * Calls {@link #validate()} to ensure consistency of configuration. + * + * @return true if the arguments were parsed successfully and execution should proceed. + * @throws Exception if there is a problem parsing the command-line arguments or the particular + * combination would violate class invariants. 
+ */ + private boolean parseArgs(String[] args) throws Exception { + + addInputOption(); + addOption("trainingOutput", "tr", "The training data output directory", false); + addOption("testOutput", "te", "The test data output directory", false); + addOption("testSplitSize", "ss", "The number of documents held back as test data for each category", false); + addOption("testSplitPct", "sp", "The % of documents held back as test data for each category", false); + addOption("splitLocation", "sl", "Location for start of test data expressed as a percentage of the input file " + + "size (0=start, 50=middle, 100=end", false); + addOption("randomSelectionSize", "rs", "The number of items to be randomly selected as test data ", false); + addOption("randomSelectionPct", "rp", "Percentage of items to be randomly selected as test data when using " + + "mapreduce mode", false); + addOption("charset", "c", "The name of the character encoding of the input files (not needed if using " + + "SequenceFiles)", false); + addOption(buildOption("sequenceFiles", "seq", "Set if the input files are sequence files. Default is false", + false, false, "false")); + addOption(DefaultOptionCreator.methodOption().create()); + addOption(DefaultOptionCreator.overwriteOption().create()); + //TODO: extend this to sequential mode + addOption("keepPct", "k", "The percentage of total data to keep in map-reduce mode, the rest will be ignored. " + + "Default is 100%", false); + addOption("mapRedOutputDir", "mro", "Output directory for map reduce jobs", false); + + if (parseArguments(args) == null) { + return false; + } + + try { + inputDirectory = getInputPath(); + + useMapRed = getOption(DefaultOptionCreator.METHOD_OPTION).equalsIgnoreCase(DefaultOptionCreator.MAPREDUCE_METHOD); + + if (useMapRed) { + if (!hasOption("randomSelectionPct")) { + throw new OptionException(getCLIOption("randomSelectionPct"), + "must set randomSelectionPct when mapRed option is used"); + } + if (!hasOption("mapRedOutputDir")) { + throw new OptionException(getCLIOption("mapRedOutputDir"), + "mapRedOutputDir must be set when mapRed option is used"); + } + mapRedOutputDirectory = new Path(getOption("mapRedOutputDir")); + if (hasOption("keepPct")) { + keepPct = Integer.parseInt(getOption("keepPct")); + } + if (hasOption(DefaultOptionCreator.OVERWRITE_OPTION)) { + HadoopUtil.delete(getConf(), mapRedOutputDirectory); + } + } else { + if (!hasOption("trainingOutput") + || !hasOption("testOutput")) { + throw new OptionException(getCLIOption("trainingOutput"), + "trainingOutput and testOutput must be set if mapRed option is not used"); + } + if (!hasOption("testSplitSize") + && !hasOption("testSplitPct") + && !hasOption("randomSelectionPct") + && !hasOption("randomSelectionSize")) { + throw new OptionException(getCLIOption("testSplitSize"), + "must set one of test split size/percentage or randomSelectionSize/percentage"); + } + + trainingOutputDirectory = new Path(getOption("trainingOutput")); + testOutputDirectory = new Path(getOption("testOutput")); + FileSystem fs = trainingOutputDirectory.getFileSystem(getConf()); + if (hasOption(DefaultOptionCreator.OVERWRITE_OPTION)) { + HadoopUtil.delete(fs.getConf(), trainingOutputDirectory); + HadoopUtil.delete(fs.getConf(), testOutputDirectory); + } + fs.mkdirs(trainingOutputDirectory); + fs.mkdirs(testOutputDirectory); + } + + if (hasOption("charset")) { + charset = Charset.forName(getOption("charset")); + } + + if (hasOption("testSplitSize") && hasOption("testSplitPct")) { + throw new 
OptionException(getCLIOption("testSplitPct"), "must have either split size or split percentage " + + "option, not BOTH"); + } + + if (hasOption("testSplitSize")) { + setTestSplitSize(Integer.parseInt(getOption("testSplitSize"))); + } + + if (hasOption("testSplitPct")) { + setTestSplitPct(Integer.parseInt(getOption("testSplitPct"))); + } + + if (hasOption("splitLocation")) { + setSplitLocation(Integer.parseInt(getOption("splitLocation"))); + } + + if (hasOption("randomSelectionSize")) { + setTestRandomSelectionSize(Integer.parseInt(getOption("randomSelectionSize"))); + } + + if (hasOption("randomSelectionPct")) { + setTestRandomSelectionPct(Integer.parseInt(getOption("randomSelectionPct"))); + } + + useSequence = hasOption("sequenceFiles"); + + } catch (OptionException e) { + log.error("Command-line option Exception", e); + CommandLineUtil.printHelp(getGroup()); + return false; + } + + validate(); + return true; + } + + /** + * Perform a split on directory specified by {@link #setInputDirectory(Path)} by calling {@link #splitFile(Path)} + * on each file found within that directory. + */ + public void splitDirectory() throws IOException, ClassNotFoundException, InterruptedException { + this.splitDirectory(inputDirectory); + } + + /** + * Perform a split on the specified directory by calling {@link #splitFile(Path)} on each file found within that + * directory. + */ + public void splitDirectory(Path inputDir) throws IOException, ClassNotFoundException, InterruptedException { + Configuration conf = getConf(); + splitDirectory(conf, inputDir); + } + + /* + * See also splitDirectory(Path inputDir) + * */ + public void splitDirectory(Configuration conf, Path inputDir) + throws IOException, ClassNotFoundException, InterruptedException { + FileSystem fs = inputDir.getFileSystem(conf); + if (fs.getFileStatus(inputDir) == null) { + throw new IOException(inputDir + " does not exist"); + } + if (!fs.getFileStatus(inputDir).isDir()) { + throw new IOException(inputDir + " is not a directory"); + } + + if (useMapRed) { + SplitInputJob.run(conf, inputDir, mapRedOutputDirectory, + keepPct, testRandomSelectionPct); + } else { + // input dir contains one file per category. + FileStatus[] fileStats = fs.listStatus(inputDir, PathFilters.logsCRCFilter()); + for (FileStatus inputFile : fileStats) { + if (!inputFile.isDir()) { + splitFile(inputFile.getPath()); + } + } + } + } + + /** + * Perform a split on the specified input file. Results will be written to files of the same name in the specified + * training and test output directories. The {@link #validate()} method is called prior to executing the split. 
+ */ + public void splitFile(Path inputFile) throws IOException { + Configuration conf = getConf(); + FileSystem fs = inputFile.getFileSystem(conf); + if (fs.getFileStatus(inputFile) == null) { + throw new IOException(inputFile + " does not exist"); + } + if (fs.getFileStatus(inputFile).isDir()) { + throw new IOException(inputFile + " is a directory"); + } + + validate(); + + Path testOutputFile = new Path(testOutputDirectory, inputFile.getName()); + Path trainingOutputFile = new Path(trainingOutputDirectory, inputFile.getName()); + + int lineCount = countLines(fs, inputFile, charset); + + log.info("{} has {} lines", inputFile.getName(), lineCount); + + int testSplitStart = 0; + int testSplitSize = this.testSplitSize; // don't modify state + BitSet randomSel = null; + + if (testRandomSelectionPct > 0 || testRandomSelectionSize > 0) { + testSplitSize = this.testRandomSelectionSize; + + if (testRandomSelectionPct > 0) { + testSplitSize = Math.round(lineCount * testRandomSelectionPct / 100.0f); + } + log.info("{} test split size is {} based on random selection percentage {}", + inputFile.getName(), testSplitSize, testRandomSelectionPct); + long[] ridx = new long[testSplitSize]; + RandomSampler.sample(testSplitSize, lineCount - 1, testSplitSize, 0, ridx, 0, RandomUtils.getRandom()); + randomSel = new BitSet(lineCount); + for (long idx : ridx) { + randomSel.set((int) idx + 1); + } + } else { + if (testSplitPct > 0) { // calculate split size based on percentage + testSplitSize = Math.round(lineCount * testSplitPct / 100.0f); + log.info("{} test split size is {} based on percentage {}", + inputFile.getName(), testSplitSize, testSplitPct); + } else { + log.info("{} test split size is {}", inputFile.getName(), testSplitSize); + } + + if (splitLocation > 0) { // calculate start of split based on percentage + testSplitStart = Math.round(lineCount * splitLocation / 100.0f); + if (lineCount - testSplitStart < testSplitSize) { + // adjust split start downwards based on split size. + testSplitStart = lineCount - testSplitSize; + } + log.info("{} test split start is {} based on split location {}", + inputFile.getName(), testSplitStart, splitLocation); + } + + if (testSplitStart < 0) { + throw new IllegalArgumentException("test split size for " + inputFile + " is too large, it would produce an " + + "empty training set from the initial set of " + lineCount + " examples"); + } else if (lineCount - testSplitSize < testSplitSize) { + log.warn("Test set size for {} may be too large, {} is larger than the number of " + + "lines remaining in the training set: {}", + inputFile, testSplitSize, lineCount - testSplitSize); + } + } + int trainCount = 0; + int testCount = 0; + if (!useSequence) { + try (BufferedReader reader = new BufferedReader(new InputStreamReader(fs.open(inputFile), charset)); + Writer trainingWriter = new OutputStreamWriter(fs.create(trainingOutputFile), charset); + Writer testWriter = new OutputStreamWriter(fs.create(testOutputFile), charset)){ + + String line; + int pos = 0; + while ((line = reader.readLine()) != null) { + pos++; + + Writer writer; + if (testRandomSelectionPct > 0) { // Randomly choose + writer = randomSel.get(pos) ? testWriter : trainingWriter; + } else { // Choose based on location + writer = pos > testSplitStart ? 
testWriter : trainingWriter; + } + + if (writer == testWriter) { + if (testCount >= testSplitSize) { + writer = trainingWriter; + } else { + testCount++; + } + } + if (writer == trainingWriter) { + trainCount++; + } + writer.write(line); + writer.write('\n'); + } + + } + } else { + try (SequenceFileIterator<Writable, Writable> iterator = + new SequenceFileIterator<>(inputFile, false, fs.getConf()); + SequenceFile.Writer trainingWriter = SequenceFile.createWriter(fs, fs.getConf(), trainingOutputFile, + iterator.getKeyClass(), iterator.getValueClass()); + SequenceFile.Writer testWriter = SequenceFile.createWriter(fs, fs.getConf(), testOutputFile, + iterator.getKeyClass(), iterator.getValueClass())) { + + int pos = 0; + while (iterator.hasNext()) { + pos++; + SequenceFile.Writer writer; + if (testRandomSelectionPct > 0) { // Randomly choose + writer = randomSel.get(pos) ? testWriter : trainingWriter; + } else { // Choose based on location + writer = pos > testSplitStart ? testWriter : trainingWriter; + } + + if (writer == testWriter) { + if (testCount >= testSplitSize) { + writer = trainingWriter; + } else { + testCount++; + } + } + if (writer == trainingWriter) { + trainCount++; + } + Pair<Writable, Writable> pair = iterator.next(); + writer.append(pair.getFirst(), pair.getSecond()); + } + + } + } + log.info("file: {}, input: {} train: {}, test: {} starting at {}", + inputFile.getName(), lineCount, trainCount, testCount, testSplitStart); + + // testing; + if (callback != null) { + callback.splitComplete(inputFile, lineCount, trainCount, testCount, testSplitStart); + } + } + + public int getTestSplitSize() { + return testSplitSize; + } + + public void setTestSplitSize(int testSplitSize) { + this.testSplitSize = testSplitSize; + } + + public int getTestSplitPct() { + return testSplitPct; + } + + /** + * Sets the percentage of the input data to allocate to the test split + * + * @param testSplitPct a value between 0 and 100 inclusive. + */ + public void setTestSplitPct(int testSplitPct) { + this.testSplitPct = testSplitPct; + } + + /** + * Sets the percentage of the input data to keep in a map reduce split input job + * + * @param keepPct a value between 0 and 100 inclusive. + */ + public void setKeepPct(int keepPct) { + this.keepPct = keepPct; + } + + /** + * Set to true to use map reduce to split the input + * + * @param useMapRed a boolean to indicate whether map reduce should be used + */ + public void setUseMapRed(boolean useMapRed) { + this.useMapRed = useMapRed; + } + + public void setMapRedOutputDirectory(Path mapRedOutputDirectory) { + this.mapRedOutputDirectory = mapRedOutputDirectory; + } + + public int getSplitLocation() { + return splitLocation; + } + + /** + * Set the location of the start of the test/training data split. Expressed as percentage of lines, for example + * 0 indicates that the test data should be taken from the start of the file, 100 indicates that the test data + * should be taken from the end of the input file, while 25 indicates that the test data should be taken from the + * first quarter of the file. + * <p/> + * This option is only relevant in cases where random selection is not employed + * + * @param splitLocation a value between 0 and 100 inclusive. 
+ */ + public void setSplitLocation(int splitLocation) { + this.splitLocation = splitLocation; + } + + public Charset getCharset() { + return charset; + } + + /** + * Set the charset used to read and write files + */ + public void setCharset(Charset charset) { + this.charset = charset; + } + + public Path getInputDirectory() { + return inputDirectory; + } + + /** + * Set the directory from which input data will be read when the the {@link #splitDirectory()} method is invoked + */ + public void setInputDirectory(Path inputDir) { + this.inputDirectory = inputDir; + } + + public Path getTrainingOutputDirectory() { + return trainingOutputDirectory; + } + + /** + * Set the directory to which training data will be written. + */ + public void setTrainingOutputDirectory(Path trainingOutputDir) { + this.trainingOutputDirectory = trainingOutputDir; + } + + public Path getTestOutputDirectory() { + return testOutputDirectory; + } + + /** + * Set the directory to which test data will be written. + */ + public void setTestOutputDirectory(Path testOutputDir) { + this.testOutputDirectory = testOutputDir; + } + + public SplitCallback getCallback() { + return callback; + } + + /** + * Sets the callback used to inform the caller that an input file has been successfully split + */ + public void setCallback(SplitCallback callback) { + this.callback = callback; + } + + public int getTestRandomSelectionSize() { + return testRandomSelectionSize; + } + + /** + * Sets number of random input samples that will be saved to the test set. + */ + public void setTestRandomSelectionSize(int testRandomSelectionSize) { + this.testRandomSelectionSize = testRandomSelectionSize; + } + + public int getTestRandomSelectionPct() { + + return testRandomSelectionPct; + } + + /** + * Sets number of random input samples that will be saved to the test set as a percentage of the size of the + * input set. + * + * @param randomSelectionPct a value between 0 and 100 inclusive. + */ + public void setTestRandomSelectionPct(int randomSelectionPct) { + this.testRandomSelectionPct = randomSelectionPct; + } + + /** + * Validates that the current instance is in a consistent state + * + * @throws IllegalArgumentException if settings violate class invariants. + * @throws IOException if output directories do not exist or are not directories. + */ + public void validate() throws IOException { + Preconditions.checkArgument(testSplitSize >= 1 || testSplitSize == -1, + "Invalid testSplitSize: " + testSplitSize + ". Must be: testSplitSize >= 1 or testSplitSize = -1"); + Preconditions.checkArgument(splitLocation >= 0 && splitLocation <= 100 || splitLocation == -1, + "Invalid splitLocation percentage: " + splitLocation + ". Must be: 0 <= splitLocation <= 100 or splitLocation = -1"); + Preconditions.checkArgument(testSplitPct >= 0 && testSplitPct <= 100 || testSplitPct == -1, + "Invalid testSplitPct percentage: " + testSplitPct + ". Must be: 0 <= testSplitPct <= 100 or testSplitPct = -1"); + Preconditions.checkArgument(testRandomSelectionPct >= 0 && testRandomSelectionPct <= 100 + || testRandomSelectionPct == -1,"Invalid testRandomSelectionPct percentage: " + testRandomSelectionPct + + ". 
Must be: 0 <= testRandomSelectionPct <= 100 or testRandomSelectionPct = -1"); + + Preconditions.checkArgument(trainingOutputDirectory != null || useMapRed, + "No training output directory was specified"); + Preconditions.checkArgument(testOutputDirectory != null || useMapRed, "No test output directory was specified"); + + // only one of the following may be set, one must be set. + int count = 0; + if (testSplitSize > 0) { + count++; + } + if (testSplitPct > 0) { + count++; + } + if (testRandomSelectionSize > 0) { + count++; + } + if (testRandomSelectionPct > 0) { + count++; + } + + Preconditions.checkArgument(count == 1, "Exactly one of testSplitSize, testSplitPct, testRandomSelectionSize, " + + "testRandomSelectionPct should be set"); + + if (!useMapRed) { + Configuration conf = getConf(); + FileSystem fs = trainingOutputDirectory.getFileSystem(conf); + FileStatus trainingOutputDirStatus = fs.getFileStatus(trainingOutputDirectory); + Preconditions.checkArgument(trainingOutputDirStatus != null && trainingOutputDirStatus.isDir(), + "%s is not a directory", trainingOutputDirectory); + FileStatus testOutputDirStatus = fs.getFileStatus(testOutputDirectory); + Preconditions.checkArgument(testOutputDirStatus != null && testOutputDirStatus.isDir(), + "%s is not a directory", testOutputDirectory); + } + } + + /** + * Count the lines in the file specified as returned by {@code BufferedReader.readLine()} + * + * @param inputFile the file whose lines will be counted + * @param charset the charset of the file to read + * @return the number of lines in the input file. + * @throws IOException if there is a problem opening or reading the file. + */ + public static int countLines(FileSystem fs, Path inputFile, Charset charset) throws IOException { + int lineCount = 0; + try (BufferedReader reader = new BufferedReader(new InputStreamReader(fs.open(inputFile), charset))){ + while (reader.readLine() != null) { + lineCount++; + } + } + return lineCount; + } + + /** + * Used to pass information back to a caller once a file has been split without the need for a data object + */ + public interface SplitCallback { + void splitComplete(Path inputFile, int lineCount, int trainCount, int testCount, int testSplitStart); + } + +}
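For orientation, here is a minimal sketch of driving the sequential (non map-reduce) split above programmatically instead of via ToolRunner. The paths and the 20 percent figure are illustrative, setConf() is assumed to be inherited from Hadoop's Configured through AbstractJob (which is not part of this diff), and the output directories are created up front because validate() expects them to exist:

    import java.nio.charset.StandardCharsets;
    import org.apache.hadoop.conf.Configuration;
    import org.apache.hadoop.fs.FileSystem;
    import org.apache.hadoop.fs.Path;
    import org.apache.mahout.utils.SplitInput;

    public class SplitInputExample {
      public static void main(String[] args) throws Exception {
        Configuration conf = new Configuration();
        Path input = new Path("category-files");     // one text file per category (illustrative path)
        Path training = new Path("split/training");
        Path test = new Path("split/test");

        // validate() checks that both output directories exist and are directories
        FileSystem fs = training.getFileSystem(conf);
        fs.mkdirs(training);
        fs.mkdirs(test);

        SplitInput splitter = new SplitInput();
        splitter.setConf(conf);                       // assumed inherited via AbstractJob/Configured
        splitter.setInputDirectory(input);
        splitter.setTrainingOutputDirectory(training);
        splitter.setTestOutputDirectory(test);
        splitter.setTestSplitPct(20);                 // exactly one sizing option may be set
        splitter.setSplitLocation(100);               // take the test items from the end of each file
        splitter.setCharset(StandardCharsets.UTF_8);
        splitter.setCallback(new SplitInput.SplitCallback() {
          @Override
          public void splitComplete(Path inputFile, int lineCount, int trainCount,
                                    int testCount, int testSplitStart) {
            System.out.println(inputFile.getName() + ": " + trainCount + " train / " + testCount + " test");
          }
        });

        splitter.splitDirectory();                    // splits each input file via splitFile(...)
      }
    }

The command-line equivalent goes through parseArgs() above, roughly: --input, --trainingOutput, --testOutput, --testSplitPct 20 and --method sequential.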
http://git-wip-us.apache.org/repos/asf/mahout/blob/e0573de3/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/SplitInputJob.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/SplitInputJob.java b/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/SplitInputJob.java new file mode 100644 index 0000000..4a1ff86 --- /dev/null +++ b/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/SplitInputJob.java @@ -0,0 +1,213 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.utils; + +import java.io.IOException; +import java.io.Serializable; +import java.util.Random; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.Writable; +import org.apache.hadoop.io.WritableComparable; +import org.apache.hadoop.io.WritableComparator; +import org.apache.hadoop.mapreduce.Job; +import org.apache.hadoop.mapreduce.Mapper; +import org.apache.hadoop.mapreduce.Reducer; +import org.apache.hadoop.mapreduce.lib.input.FileInputFormat; +import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat; +import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat; +import org.apache.hadoop.mapreduce.lib.output.MultipleOutputs; +import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat; +import org.apache.mahout.common.Pair; +import org.apache.mahout.common.RandomUtils; +import org.apache.mahout.common.iterator.sequencefile.PathFilters; +import org.apache.mahout.common.iterator.sequencefile.PathType; +import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterator; + +/** + * Class which implements a map reduce version of SplitInput. + * This class takes a SequenceFile input, e.g. a set of training data + * for a learning algorithm, downsamples it, applies a random + * permutation and splits it into test and training sets + */ +public final class SplitInputJob { + + private static final String DOWNSAMPLING_FACTOR = "SplitInputJob.downsamplingFactor"; + private static final String RANDOM_SELECTION_PCT = "SplitInputJob.randomSelectionPct"; + private static final String TRAINING_TAG = "training"; + private static final String TEST_TAG = "test"; + + private SplitInputJob() {} + + /** + * Run job to downsample, randomly permute and split data into test and + * training sets. 
This job takes a SequenceFile as input and outputs two + * SequenceFiles test-r-00000 and training-r-00000 which contain the test and + * training sets respectively + * + * @param initialConf + * Initial configuration + * @param inputPath + * path to input data SequenceFile + * @param outputPath + * path for output data SequenceFiles + * @param keepPct + * percentage of key value pairs in input to keep. The rest are + * discarded + * @param randomSelectionPercent + * percentage of key value pairs to allocate to test set. Remainder + * are allocated to training set + */ + @SuppressWarnings("rawtypes") + public static void run(Configuration initialConf, Path inputPath, + Path outputPath, int keepPct, float randomSelectionPercent) + throws IOException, ClassNotFoundException, InterruptedException { + + int downsamplingFactor = (int) (100.0 / keepPct); + initialConf.setInt(DOWNSAMPLING_FACTOR, downsamplingFactor); + initialConf.setFloat(RANDOM_SELECTION_PCT, randomSelectionPercent); + + // Determine class of keys and values + FileSystem fs = FileSystem.get(initialConf); + + SequenceFileDirIterator<? extends WritableComparable, Writable> iterator = + new SequenceFileDirIterator<>(inputPath, + PathType.LIST, PathFilters.partFilter(), null, false, fs.getConf()); + Class<? extends WritableComparable> keyClass; + Class<? extends Writable> valueClass; + if (iterator.hasNext()) { + Pair<? extends WritableComparable, Writable> pair = iterator.next(); + keyClass = pair.getFirst().getClass(); + valueClass = pair.getSecond().getClass(); + } else { + throw new IllegalStateException("Couldn't determine class of the input values"); + } + + Job job = new Job(new Configuration(initialConf)); + + MultipleOutputs.addNamedOutput(job, TRAINING_TAG, SequenceFileOutputFormat.class, keyClass, valueClass); + MultipleOutputs.addNamedOutput(job, TEST_TAG, SequenceFileOutputFormat.class, keyClass, valueClass); + job.setJarByClass(SplitInputJob.class); + FileInputFormat.addInputPath(job, inputPath); + FileOutputFormat.setOutputPath(job, outputPath); + job.setNumReduceTasks(1); + job.setInputFormatClass(SequenceFileInputFormat.class); + job.setOutputFormatClass(SequenceFileOutputFormat.class); + job.setMapperClass(SplitInputMapper.class); + job.setReducerClass(SplitInputReducer.class); + job.setSortComparatorClass(SplitInputComparator.class); + job.setOutputKeyClass(keyClass); + job.setOutputValueClass(valueClass); + job.submit(); + boolean succeeded = job.waitForCompletion(true); + if (!succeeded) { + throw new IllegalStateException("Job failed!"); + } + } + + /** Mapper which downsamples the input by downsamplingFactor */ + public static class SplitInputMapper extends + Mapper<WritableComparable<?>, Writable, WritableComparable<?>, Writable> { + + private int downsamplingFactor; + + @Override + public void setup(Context ctx) { + downsamplingFactor = ctx.getConfiguration().getInt(DOWNSAMPLING_FACTOR, 1); + } + + /** Only run map() for one out of every downsampleFactor inputs */ + @Override + public void run(Context context) throws IOException, InterruptedException { + setup(context); + int i = 0; + while (context.nextKeyValue()) { + if (i % downsamplingFactor == 0) { + map(context.getCurrentKey(), context.getCurrentValue(), context); + } + i++; + } + cleanup(context); + } + + } + + /** Reducer which uses MultipleOutputs to randomly allocate key value pairs between test and training outputs */ + public static class SplitInputReducer extends + Reducer<WritableComparable<?>, Writable, WritableComparable<?>, Writable> { + + 
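+ // Together with SplitInputMapper (which downsamples the input) and
+ // SplitInputComparator (whose randomized sort order yields a random
+ // permutation), this single reducer performs the actual split: each
+ // key/value pair is routed to the "training" or "test" named output.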
private MultipleOutputs multipleOutputs; + private final Random rnd = RandomUtils.getRandom(); + private float randomSelectionPercent; + + @Override + protected void setup(Context ctx) throws IOException { + randomSelectionPercent = ctx.getConfiguration().getFloat(RANDOM_SELECTION_PCT, 0); + multipleOutputs = new MultipleOutputs(ctx); + } + + /** + * Randomly allocate key value pairs between test and training sets. + * randomSelectionPercent of the pairs will go to the test set. + */ + @Override + protected void reduce(WritableComparable<?> key, Iterable<Writable> values, + Context context) throws IOException, InterruptedException { + for (Writable value : values) { + if (rnd.nextInt(100) < randomSelectionPercent) { + multipleOutputs.write(TEST_TAG, key, value); + } else { + multipleOutputs.write(TRAINING_TAG, key, value); + } + } + + } + + @Override + protected void cleanup(Context context) throws IOException { + try { + multipleOutputs.close(); + } catch (InterruptedException e) { + throw new IOException(e); + } + } + + } + + /** Randomly permute key value pairs */ + public static class SplitInputComparator extends WritableComparator implements Serializable { + + private final Random rnd = RandomUtils.getRandom(); + + protected SplitInputComparator() { + super(WritableComparable.class); + } + + @Override + public int compare(byte[] b1, int s1, int l1, byte[] b2, int s2, int l2) { + if (rnd.nextBoolean()) { + return 1; + } else { + return -1; + } + } + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/e0573de3/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/clustering/AbstractClusterWriter.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/clustering/AbstractClusterWriter.java b/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/clustering/AbstractClusterWriter.java new file mode 100644 index 0000000..ac884d0 --- /dev/null +++ b/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/clustering/AbstractClusterWriter.java @@ -0,0 +1,160 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.utils.clustering; + +import java.io.IOException; +import java.io.Writer; +import java.util.Collection; +import java.util.Collections; +import java.util.Comparator; +import java.util.Iterator; +import java.util.List; +import java.util.Map; + +import org.apache.commons.lang3.StringUtils; +import org.apache.mahout.clustering.classify.WeightedPropertyVectorWritable; +import org.apache.mahout.clustering.iterator.ClusterWritable; +import org.apache.mahout.common.Pair; +import org.apache.mahout.common.distance.DistanceMeasure; +import org.apache.mahout.math.Vector; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import com.google.common.collect.Lists; + +/** + * Base class for implementing ClusterWriter + */ +public abstract class AbstractClusterWriter implements ClusterWriter { + + private static final Logger log = LoggerFactory.getLogger(AbstractClusterWriter.class); + + protected final Writer writer; + protected final Map<Integer, List<WeightedPropertyVectorWritable>> clusterIdToPoints; + protected final DistanceMeasure measure; + + /** + * + * @param writer The underlying {@link java.io.Writer} to use + * @param clusterIdToPoints The map between cluster ids {@link org.apache.mahout.clustering.Cluster#getId()} and the + * points in the cluster + * @param measure The {@link org.apache.mahout.common.distance.DistanceMeasure} used to calculate the distance. + * Some writers may wish to use it for calculating weights for display. May be null. + */ + protected AbstractClusterWriter(Writer writer, Map<Integer, List<WeightedPropertyVectorWritable>> clusterIdToPoints, + DistanceMeasure measure) { + this.writer = writer; + this.clusterIdToPoints = clusterIdToPoints; + this.measure = measure; + } + + protected Writer getWriter() { + return writer; + } + + protected Map<Integer, List<WeightedPropertyVectorWritable>> getClusterIdToPoints() { + return clusterIdToPoints; + } + + public static String getTopFeatures(Vector vector, String[] dictionary, int numTerms) { + + StringBuilder sb = new StringBuilder(100); + + for (Pair<String, Double> item : getTopPairs(vector, dictionary, numTerms)) { + String term = item.getFirst(); + sb.append("\n\t\t"); + sb.append(StringUtils.rightPad(term, 40)); + sb.append("=>"); + sb.append(StringUtils.leftPad(item.getSecond().toString(), 20)); + } + return sb.toString(); + } + + public static String getTopTerms(Vector vector, String[] dictionary, int numTerms) { + + StringBuilder sb = new StringBuilder(100); + + for (Pair<String, Double> item : getTopPairs(vector, dictionary, numTerms)) { + String term = item.getFirst(); + sb.append(term).append('_'); + } + sb.deleteCharAt(sb.length() - 1); + return sb.toString(); + } + + @Override + public long write(Iterable<ClusterWritable> iterable) throws IOException { + return write(iterable, Long.MAX_VALUE); + } + + @Override + public void close() throws IOException { + writer.close(); + } + + @Override + public long write(Iterable<ClusterWritable> iterable, long maxDocs) throws IOException { + long result = 0; + Iterator<ClusterWritable> iterator = iterable.iterator(); + while (result < maxDocs && iterator.hasNext()) { + write(iterator.next()); + result++; + } + return result; + } + + private static Collection<Pair<String, Double>> getTopPairs(Vector vector, String[] dictionary, int numTerms) { + List<TermIndexWeight> vectorTerms = Lists.newArrayList(); + + for (Vector.Element elt : vector.nonZeroes()) { + vectorTerms.add(new TermIndexWeight(elt.index(), elt.get())); + } + + // Sort 
results in reverse order (ie weight in descending order) + Collections.sort(vectorTerms, new Comparator<TermIndexWeight>() { + @Override + public int compare(TermIndexWeight one, TermIndexWeight two) { + return Double.compare(two.weight, one.weight); + } + }); + + Collection<Pair<String, Double>> topTerms = Lists.newLinkedList(); + + for (int i = 0; i < vectorTerms.size() && i < numTerms; i++) { + int index = vectorTerms.get(i).index; + String dictTerm = dictionary[index]; + if (dictTerm == null) { + log.error("Dictionary entry missing for {}", index); + continue; + } + topTerms.add(new Pair<>(dictTerm, vectorTerms.get(i).weight)); + } + + return topTerms; + } + + private static class TermIndexWeight { + private final int index; + private final double weight; + + TermIndexWeight(int index, double weight) { + this.index = index; + this.weight = weight; + } + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/e0573de3/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/clustering/CSVClusterWriter.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/clustering/CSVClusterWriter.java b/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/clustering/CSVClusterWriter.java new file mode 100644 index 0000000..7269016 --- /dev/null +++ b/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/clustering/CSVClusterWriter.java @@ -0,0 +1,69 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.utils.clustering; + +import org.apache.mahout.clustering.Cluster; +import org.apache.mahout.clustering.classify.WeightedPropertyVectorWritable; +import org.apache.mahout.clustering.iterator.ClusterWritable; +import org.apache.mahout.common.distance.DistanceMeasure; +import org.apache.mahout.math.NamedVector; +import org.apache.mahout.math.Vector; + +import java.io.IOException; +import java.io.Writer; +import java.util.List; +import java.util.Map; +import java.util.regex.Pattern; + +/** + * Format is adjacency style as put forth at http://gephi.org/users/supported-graph-formats/csv-format/, the centroid + * is the first element and all the rest of the row are the points in that cluster + * + **/ +public class CSVClusterWriter extends AbstractClusterWriter { + + private static final Pattern VEC_PATTERN = Pattern.compile("\\{|\\:|\\,|\\}"); + + public CSVClusterWriter(Writer writer, Map<Integer, List<WeightedPropertyVectorWritable>> clusterIdToPoints, + DistanceMeasure measure) { + super(writer, clusterIdToPoints, measure); + } + + @Override + public void write(ClusterWritable clusterWritable) throws IOException { + StringBuilder line = new StringBuilder(); + Cluster cluster = clusterWritable.getValue(); + line.append(cluster.getId()); + List<WeightedPropertyVectorWritable> points = getClusterIdToPoints().get(cluster.getId()); + if (points != null) { + for (WeightedPropertyVectorWritable point : points) { + Vector theVec = point.getVector(); + line.append(','); + if (theVec instanceof NamedVector) { + line.append(((NamedVector)theVec).getName()); + } else { + String vecStr = theVec.asFormatString(); + //do some basic manipulations for display + vecStr = VEC_PATTERN.matcher(vecStr).replaceAll("_"); + line.append(vecStr); + } + } + getWriter().append(line).append("\n"); + } + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/e0573de3/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/clustering/ClusterDumper.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/clustering/ClusterDumper.java b/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/clustering/ClusterDumper.java new file mode 100644 index 0000000..75b5ded --- /dev/null +++ b/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/clustering/ClusterDumper.java @@ -0,0 +1,328 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.utils.clustering; + +import java.io.File; +import java.io.IOException; +import java.io.OutputStreamWriter; +import java.io.Writer; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.TreeMap; + +import com.google.common.io.Closeables; +import com.google.common.io.Files; +import org.apache.commons.io.Charsets; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.IntWritable; +import org.apache.mahout.clustering.cdbw.CDbwEvaluator; +import org.apache.mahout.clustering.classify.WeightedPropertyVectorWritable; +import org.apache.mahout.clustering.evaluation.ClusterEvaluator; +import org.apache.mahout.clustering.evaluation.RepresentativePointsDriver; +import org.apache.mahout.clustering.iterator.ClusterWritable; +import org.apache.mahout.common.AbstractJob; +import org.apache.mahout.common.ClassUtils; +import org.apache.mahout.common.HadoopUtil; +import org.apache.mahout.common.Pair; +import org.apache.mahout.common.commandline.DefaultOptionCreator; +import org.apache.mahout.common.distance.DistanceMeasure; +import org.apache.mahout.common.iterator.sequencefile.PathFilters; +import org.apache.mahout.common.iterator.sequencefile.PathType; +import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable; +import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirValueIterable; +import org.apache.mahout.utils.vectors.VectorHelper; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public final class ClusterDumper extends AbstractJob { + + public static final String SAMPLE_POINTS = "samplePoints"; + DistanceMeasure measure; + + public enum OUTPUT_FORMAT { + TEXT, + CSV, + GRAPH_ML, + JSON, + } + + public static final String DICTIONARY_TYPE_OPTION = "dictionaryType"; + public static final String DICTIONARY_OPTION = "dictionary"; + public static final String POINTS_DIR_OPTION = "pointsDir"; + public static final String NUM_WORDS_OPTION = "numWords"; + public static final String SUBSTRING_OPTION = "substring"; + public static final String EVALUATE_CLUSTERS = "evaluate"; + + public static final String OUTPUT_FORMAT_OPT = "outputFormat"; + + private static final Logger log = LoggerFactory.getLogger(ClusterDumper.class); + private Path seqFileDir; + private Path pointsDir; + private long maxPointsPerCluster = Long.MAX_VALUE; + private String termDictionary; + private String dictionaryFormat; + private int subString = Integer.MAX_VALUE; + private int numTopFeatures = 10; + private Map<Integer, List<WeightedPropertyVectorWritable>> clusterIdToPoints; + private OUTPUT_FORMAT outputFormat = OUTPUT_FORMAT.TEXT; + private boolean runEvaluation; + + public ClusterDumper(Path seqFileDir, Path pointsDir) { + this.seqFileDir = seqFileDir; + this.pointsDir = pointsDir; + init(); + } + + public ClusterDumper() { + setConf(new Configuration()); + } + + public static void main(String[] args) throws Exception { + new ClusterDumper().run(args); + } + + @Override + public int run(String[] args) throws Exception { + addInputOption(); + addOutputOption(); + addOption(OUTPUT_FORMAT_OPT, "of", "The optional output format for the results. 
Options: TEXT, CSV, JSON or GRAPH_ML", + "TEXT"); + addOption(SUBSTRING_OPTION, "b", "The number of chars of the asFormatString() to print"); + addOption(NUM_WORDS_OPTION, "n", "The number of top terms to print"); + addOption(POINTS_DIR_OPTION, "p", + "The directory containing points sequence files mapping input vectors to their cluster. " + + "If specified, then the program will output the points associated with a cluster"); + addOption(SAMPLE_POINTS, "sp", "Specifies the maximum number of points to include _per_ cluster. The default " + + "is to include all points"); + addOption(DICTIONARY_OPTION, "d", "The dictionary file"); + addOption(DICTIONARY_TYPE_OPTION, "dt", "The dictionary file type (text|sequencefile)", "text"); + addOption(buildOption(EVALUATE_CLUSTERS, "e", "Run ClusterEvaluator and CDbwEvaluator over the input. " + + "The output will be appended to the rest of the output at the end.", false, false, null)); + addOption(DefaultOptionCreator.distanceMeasureOption().create()); + + // output is optional, will print to System.out per default + if (parseArguments(args, false, true) == null) { + return -1; + } + + seqFileDir = getInputPath(); + if (hasOption(POINTS_DIR_OPTION)) { + pointsDir = new Path(getOption(POINTS_DIR_OPTION)); + } + outputFile = getOutputFile(); + if (hasOption(SUBSTRING_OPTION)) { + int sub = Integer.parseInt(getOption(SUBSTRING_OPTION)); + if (sub >= 0) { + subString = sub; + } + } + termDictionary = getOption(DICTIONARY_OPTION); + dictionaryFormat = getOption(DICTIONARY_TYPE_OPTION); + if (hasOption(NUM_WORDS_OPTION)) { + numTopFeatures = Integer.parseInt(getOption(NUM_WORDS_OPTION)); + } + if (hasOption(OUTPUT_FORMAT_OPT)) { + outputFormat = OUTPUT_FORMAT.valueOf(getOption(OUTPUT_FORMAT_OPT)); + } + if (hasOption(SAMPLE_POINTS)) { + maxPointsPerCluster = Long.parseLong(getOption(SAMPLE_POINTS)); + } else { + maxPointsPerCluster = Long.MAX_VALUE; + } + runEvaluation = hasOption(EVALUATE_CLUSTERS); + String distanceMeasureClass = getOption(DefaultOptionCreator.DISTANCE_MEASURE_OPTION); + measure = ClassUtils.instantiateAs(distanceMeasureClass, DistanceMeasure.class); + + init(); + printClusters(null); + return 0; + } + + public void printClusters(String[] dictionary) throws Exception { + Configuration conf = new Configuration(); + + if (this.termDictionary != null) { + if ("text".equals(dictionaryFormat)) { + dictionary = VectorHelper.loadTermDictionary(new File(this.termDictionary)); + } else if ("sequencefile".equals(dictionaryFormat)) { + dictionary = VectorHelper.loadTermDictionary(conf, this.termDictionary); + } else { + throw new IllegalArgumentException("Invalid dictionary format"); + } + } + + Writer writer; + boolean shouldClose; + if (this.outputFile == null) { + shouldClose = false; + writer = new OutputStreamWriter(System.out, Charsets.UTF_8); + } else { + shouldClose = true; + if (outputFile.getName().startsWith("s3n://")) { + Path p = outputPath; + FileSystem fs = FileSystem.get(p.toUri(), conf); + writer = new OutputStreamWriter(fs.create(p), Charsets.UTF_8); + } else { + Files.createParentDirs(outputFile); + writer = Files.newWriter(this.outputFile, Charsets.UTF_8); + } + } + ClusterWriter clusterWriter = createClusterWriter(writer, dictionary); + try { + long numWritten = clusterWriter.write(new SequenceFileDirValueIterable<ClusterWritable>(new Path(seqFileDir, + "part-*"), PathType.GLOB, conf)); + + writer.flush(); + if (runEvaluation) { + HadoopUtil.delete(conf, new Path("tmp/representative")); + int numIters = 5; + 
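+ // Build representative points for the evaluators below: the driver writes
+ // per-iteration state under tmp/representative, and the final iteration is
+ // passed to ClusterEvaluator and CDbwEvaluator through STATE_IN_KEY.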
RepresentativePointsDriver.main(new String[]{ + "--input", seqFileDir.toString(), + "--output", "tmp/representative", + "--clusteredPoints", pointsDir.toString(), + "--distanceMeasure", measure.getClass().getName(), + "--maxIter", String.valueOf(numIters) + }); + conf.set(RepresentativePointsDriver.DISTANCE_MEASURE_KEY, measure.getClass().getName()); + conf.set(RepresentativePointsDriver.STATE_IN_KEY, "tmp/representative/representativePoints-" + numIters); + ClusterEvaluator ce = new ClusterEvaluator(conf, seqFileDir); + writer.append("\n"); + writer.append("Inter-Cluster Density: ").append(String.valueOf(ce.interClusterDensity())).append("\n"); + writer.append("Intra-Cluster Density: ").append(String.valueOf(ce.intraClusterDensity())).append("\n"); + CDbwEvaluator cdbw = new CDbwEvaluator(conf, seqFileDir); + writer.append("CDbw Inter-Cluster Density: ").append(String.valueOf(cdbw.interClusterDensity())).append("\n"); + writer.append("CDbw Intra-Cluster Density: ").append(String.valueOf(cdbw.intraClusterDensity())).append("\n"); + writer.append("CDbw Separation: ").append(String.valueOf(cdbw.separation())).append("\n"); + writer.flush(); + } + log.info("Wrote {} clusters", numWritten); + } finally { + if (shouldClose) { + Closeables.close(clusterWriter, false); + } else { + if (clusterWriter instanceof GraphMLClusterWriter) { + clusterWriter.close(); + } + } + } + } + + ClusterWriter createClusterWriter(Writer writer, String[] dictionary) throws IOException { + ClusterWriter result; + + switch (outputFormat) { + case TEXT: + result = new ClusterDumperWriter(writer, clusterIdToPoints, measure, numTopFeatures, dictionary, subString); + break; + case CSV: + result = new CSVClusterWriter(writer, clusterIdToPoints, measure); + break; + case GRAPH_ML: + result = new GraphMLClusterWriter(writer, clusterIdToPoints, measure, numTopFeatures, dictionary, subString); + break; + case JSON: + result = new JsonClusterWriter(writer, clusterIdToPoints, measure, numTopFeatures, dictionary); + break; + default: + throw new IllegalStateException("Unknown outputformat: " + outputFormat); + } + return result; + } + + /** + * Convenience function to set the output format during testing. 
+ */ + public void setOutputFormat(OUTPUT_FORMAT of) { + outputFormat = of; + } + + private void init() { + if (this.pointsDir != null) { + Configuration conf = new Configuration(); + // read in the points + clusterIdToPoints = readPoints(this.pointsDir, maxPointsPerCluster, conf); + } else { + clusterIdToPoints = Collections.emptyMap(); + } + } + + + public int getSubString() { + return subString; + } + + public void setSubString(int subString) { + this.subString = subString; + } + + public Map<Integer, List<WeightedPropertyVectorWritable>> getClusterIdToPoints() { + return clusterIdToPoints; + } + + public String getTermDictionary() { + return termDictionary; + } + + public void setTermDictionary(String termDictionary, String dictionaryType) { + this.termDictionary = termDictionary; + this.dictionaryFormat = dictionaryType; + } + + public void setNumTopFeatures(int num) { + this.numTopFeatures = num; + } + + public int getNumTopFeatures() { + return this.numTopFeatures; + } + + public long getMaxPointsPerCluster() { + return maxPointsPerCluster; + } + + public void setMaxPointsPerCluster(long maxPointsPerCluster) { + this.maxPointsPerCluster = maxPointsPerCluster; + } + + public static Map<Integer, List<WeightedPropertyVectorWritable>> readPoints(Path pointsPathDir, + long maxPointsPerCluster, + Configuration conf) { + Map<Integer, List<WeightedPropertyVectorWritable>> result = new TreeMap<>(); + for (Pair<IntWritable, WeightedPropertyVectorWritable> record + : new SequenceFileDirIterable<IntWritable, WeightedPropertyVectorWritable>(pointsPathDir, PathType.LIST, + PathFilters.logsCRCFilter(), conf)) { + // value is the cluster id as an int, key is the name/id of the + // vector, but that doesn't matter because we only care about printing it + //String clusterId = value.toString(); + int keyValue = record.getFirst().get(); + List<WeightedPropertyVectorWritable> pointList = result.get(keyValue); + if (pointList == null) { + pointList = new ArrayList<>(); + result.put(keyValue, pointList); + } + if (pointList.size() < maxPointsPerCluster) { + pointList.add(record.getSecond()); + } + } + return result; + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/e0573de3/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/clustering/ClusterDumperWriter.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/clustering/ClusterDumperWriter.java b/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/clustering/ClusterDumperWriter.java new file mode 100644 index 0000000..31858c4 --- /dev/null +++ b/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/clustering/ClusterDumperWriter.java @@ -0,0 +1,100 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.utils.clustering; + +import org.apache.hadoop.io.Text; +import org.apache.mahout.clustering.AbstractCluster; +import org.apache.mahout.clustering.Cluster; +import org.apache.mahout.clustering.classify.WeightedPropertyVectorWritable; +import org.apache.mahout.clustering.iterator.ClusterWritable; +import org.apache.mahout.common.distance.DistanceMeasure; + +import java.io.IOException; +import java.io.Writer; +import java.util.Iterator; +import java.util.List; +import java.util.Map; + +/** + * Implements a {@link ClusterWriter} that outputs in the format used by ClusterDumper in Mahout 0.5 + */ +public class ClusterDumperWriter extends AbstractClusterWriter { + + private final int subString; + private final String[] dictionary; + private final int numTopFeatures; + + public ClusterDumperWriter(Writer writer, Map<Integer,List<WeightedPropertyVectorWritable>> clusterIdToPoints, + DistanceMeasure measure, int numTopFeatures, String[] dictionary, int subString) { + super(writer, clusterIdToPoints, measure); + this.numTopFeatures = numTopFeatures; + this.dictionary = dictionary; + this.subString = subString; + } + + @Override + public void write(ClusterWritable clusterWritable) throws IOException { + Cluster cluster = clusterWritable.getValue(); + String fmtStr = cluster.asFormatString(dictionary); + Writer writer = getWriter(); + if (subString > 0 && fmtStr.length() > subString) { + writer.write(':'); + writer.write(fmtStr, 0, Math.min(subString, fmtStr.length())); + } else { + writer.write(fmtStr); + } + + writer.write('\n'); + + if (dictionary != null) { + String topTerms = getTopFeatures(clusterWritable.getValue().getCenter(), dictionary, numTopFeatures); + writer.write("\tTop Terms: "); + writer.write(topTerms); + writer.write('\n'); + } + + Map<Integer,List<WeightedPropertyVectorWritable>> clusterIdToPoints = getClusterIdToPoints(); + List<WeightedPropertyVectorWritable> points = clusterIdToPoints.get(clusterWritable.getValue().getId()); + if (points != null) { + writer.write("\tWeight : [props - optional]: Point:\n\t"); + for (Iterator<WeightedPropertyVectorWritable> iterator = points.iterator(); iterator.hasNext();) { + WeightedPropertyVectorWritable point = iterator.next(); + writer.write(String.valueOf(point.getWeight())); + Map<Text,Text> map = point.getProperties(); + // map can be null since empty maps when written are returned as null + writer.write(" : ["); + if (map != null) { + for (Map.Entry<Text,Text> entry : map.entrySet()) { + writer.write(entry.getKey().toString()); + writer.write("="); + writer.write(entry.getValue().toString()); + } + } + writer.write("]"); + + writer.write(": "); + + writer.write(AbstractCluster.formatVector(point.getVector(), dictionary)); + if (iterator.hasNext()) { + writer.write("\n\t"); + } + } + writer.write('\n'); + } + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/e0573de3/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/clustering/ClusterWriter.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/clustering/ClusterWriter.java b/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/clustering/ClusterWriter.java new file mode 100644 index 0000000..70f8f6f --- /dev/null +++ 
b/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/clustering/ClusterWriter.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.utils.clustering; + +import java.io.Closeable; +import java.io.IOException; + +import org.apache.mahout.clustering.iterator.ClusterWritable; + +/** + * Writes out clusters + */ +public interface ClusterWriter extends Closeable { + + /** + * Write all values in the Iterable to the output + * + * @param iterable The {@link Iterable} to loop over + * @return the number of docs written + * @throws java.io.IOException if there was a problem writing + */ + long write(Iterable<ClusterWritable> iterable) throws IOException; + + /** + * Write out a Cluster + */ + void write(ClusterWritable clusterWritable) throws IOException; + + /** + * Write the first {@code maxDocs} to the output. + * + * @param iterable The {@link Iterable} to loop over + * @param maxDocs the maximum number of docs to write + * @return The number of docs written + * @throws IOException if there was a problem writing + */ + long write(Iterable<ClusterWritable> iterable, long maxDocs) throws IOException; +} http://git-wip-us.apache.org/repos/asf/mahout/blob/e0573de3/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/clustering/GraphMLClusterWriter.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/clustering/GraphMLClusterWriter.java b/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/clustering/GraphMLClusterWriter.java new file mode 100644 index 0000000..25e8f3b --- /dev/null +++ b/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/clustering/GraphMLClusterWriter.java @@ -0,0 +1,216 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.utils.clustering; + +import java.io.IOException; +import java.io.Writer; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Random; +import java.util.regex.Pattern; + +import org.apache.mahout.clustering.Cluster; +import org.apache.mahout.clustering.classify.WeightedPropertyVectorWritable; +import org.apache.mahout.clustering.classify.WeightedVectorWritable; +import org.apache.mahout.clustering.iterator.ClusterWritable; +import org.apache.mahout.common.RandomUtils; +import org.apache.mahout.common.StringUtils; +import org.apache.mahout.common.distance.DistanceMeasure; +import org.apache.mahout.math.NamedVector; +import org.apache.mahout.math.Vector; + +/** + * GraphML -- see http://gephi.org/users/supported-graph-formats/graphml-format/ + */ +public class GraphMLClusterWriter extends AbstractClusterWriter { + + private static final Pattern VEC_PATTERN = Pattern.compile("\\{|\\:|\\,|\\}"); + private final Map<Integer, Color> colors = new HashMap<>(); + private Color lastClusterColor; + private float lastX; + private float lastY; + private Random random; + private int posStep; + private final String[] dictionary; + private final int numTopFeatures; + private final int subString; + + public GraphMLClusterWriter(Writer writer, Map<Integer, List<WeightedPropertyVectorWritable>> clusterIdToPoints, + DistanceMeasure measure, int numTopFeatures, String[] dictionary, int subString) + throws IOException { + super(writer, clusterIdToPoints, measure); + this.dictionary = dictionary; + this.numTopFeatures = numTopFeatures; + this.subString = subString; + init(writer); + } + + private void init(Writer writer) throws IOException { + writer.append("<?xml version=\"1.0\" encoding=\"UTF-8\"?>"); + writer.append("<graphml xmlns=\"http://graphml.graphdrawing.org/xmlns\"\n" + + "xmlns:xsi=\"http://www.w3.org/2001/XMLSchema-instance\"\n" + + "xsi:schemaLocation=\"http://graphml.graphdrawing.org/xmlns\n" + + "http://graphml.graphdrawing.org/xmlns/1.0/graphml.xsd\">"); + //support rgb + writer.append("<key attr.name=\"r\" attr.type=\"int\" for=\"node\" id=\"r\"/>\n" + + "<key attr.name=\"g\" attr.type=\"int\" for=\"node\" id=\"g\"/>\n" + + "<key attr.name=\"b\" attr.type=\"int\" for=\"node\" id=\"b\"/>" + + "<key attr.name=\"size\" attr.type=\"int\" for=\"node\" id=\"size\"/>" + + "<key attr.name=\"weight\" attr.type=\"float\" for=\"edge\" id=\"weight\"/>" + + "<key attr.name=\"x\" attr.type=\"float\" for=\"node\" id=\"x\"/>" + + "<key attr.name=\"y\" attr.type=\"float\" for=\"node\" id=\"y\"/>"); + writer.append("<graph edgedefault=\"undirected\">"); + lastClusterColor = new Color(); + posStep = (int) (0.1 * clusterIdToPoints.size()) + 100; + random = RandomUtils.getRandom(); + } + + /* + <?xml version="1.0" encoding="UTF-8"?> + <graphml xmlns="http://graphml.graphdrawing.org/xmlns" + xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" + xsi:schemaLocation="http://graphml.graphdrawing.org/xmlns + http://graphml.graphdrawing.org/xmlns/1.0/graphml.xsd"> + <graph id="G" edgedefault="undirected"> + <node id="n0"/> + <node id="n1"/> + <edge id="e1" source="n0" target="n1"/> + </graph> + </graphml> + */ + + @Override + public void write(ClusterWritable clusterWritable) throws IOException { + StringBuilder line = new StringBuilder(); + Cluster cluster = clusterWritable.getValue(); + Color rgb = getColor(cluster.getId()); + + String topTerms = ""; + if (dictionary != null) { + topTerms = getTopTerms(cluster.getCenter(), dictionary, 
numTopFeatures); + } + String clusterLabel = String.valueOf(cluster.getId()) + '_' + topTerms; + //do some positioning so that items are visible and grouped together + //TODO: put in a real layout algorithm + float x = lastX + 1000; + float y = lastY; + if (x > (1000 + posStep)) { + y = lastY + 1000; + x = 0; + } + + line.append(createNode(clusterLabel, rgb, x, y)); + List<WeightedPropertyVectorWritable> points = clusterIdToPoints.get(cluster.getId()); + if (points != null) { + for (WeightedVectorWritable point : points) { + Vector theVec = point.getVector(); + double distance = 1; + if (measure != null) { + //scale the distance + distance = measure.distance(cluster.getCenter().getLengthSquared(), cluster.getCenter(), theVec) * 500; + } + String vecStr; + int angle = random.nextInt(360); //pick an angle at random and then scale along that angle + double angleRads = Math.toRadians(angle); + + float targetX = x + (float) (distance * Math.cos(angleRads)); + float targetY = y + (float) (distance * Math.sin(angleRads)); + if (theVec instanceof NamedVector) { + vecStr = ((NamedVector) theVec).getName(); + } else { + vecStr = theVec.asFormatString(); + //do some basic manipulations for display + vecStr = VEC_PATTERN.matcher(vecStr).replaceAll("_"); + } + if (subString > 0 && vecStr.length() > subString) { + vecStr = vecStr.substring(0, subString); + } + line.append(createNode(vecStr, rgb, targetX, targetY)); + line.append(createEdge(clusterLabel, vecStr, distance)); + } + } + lastClusterColor = rgb; + lastX = x; + lastY = y; + getWriter().append(line).append("\n"); + } + + private Color getColor(int clusterId) { + Color result = colors.get(clusterId); + if (result == null) { + result = new Color(); + //there is probably some better way to color a graph + int incR = 0; + int incG = 0; + int incB = 0; + if (lastClusterColor.r + 20 < 256 && lastClusterColor.g + 20 < 256 && lastClusterColor.b + 20 < 256) { + incR = 20; + incG = 0; + incB = 0; + } else if (lastClusterColor.r + 20 >= 256 && lastClusterColor.g + 20 < 256 && lastClusterColor.b + 20 < 256) { + incG = 20; + incB = 0; + } else if (lastClusterColor.r + 20 >= 256 && lastClusterColor.g + 20 >= 256 && lastClusterColor.b + 20 < 256) { + incB = 20; + } else { + incR += 3; + incG += 3; + incR += 3; + } + result.r = (lastClusterColor.r + incR) % 256; + result.g = (lastClusterColor.g + incG) % 256; + result.b = (lastClusterColor.b + incB) % 256; + colors.put(clusterId, result); + } + return result; + } + + private static String createEdge(String left, String right, double distance) { + left = StringUtils.escapeXML(left); + right = StringUtils.escapeXML(right); + return "<edge id=\"" + left + '_' + right + "\" source=\"" + left + "\" target=\"" + right + "\">" + + "<data key=\"weight\">" + distance + "</data></edge>"; + } + + private static String createNode(String s, Color rgb, float x, float y) { + return "<node id=\"" + StringUtils.escapeXML(s) + "\"><data key=\"r\">" + rgb.r + + "</data>" + + "<data key=\"g\">" + rgb.g + + "</data>" + + "<data key=\"b\">" + rgb.b + + "</data>" + + "<data key=\"x\">" + x + + "</data>" + + "<data key=\"y\">" + y + + "</data>" + + "</node>"; + } + + @Override + public void close() throws IOException { + getWriter().append("</graph>").append("</graphml>"); + super.close(); + } + + private static class Color { + int r; + int g; + int b; + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/e0573de3/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/clustering/JsonClusterWriter.java 
---------------------------------------------------------------------- diff --git a/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/clustering/JsonClusterWriter.java b/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/clustering/JsonClusterWriter.java new file mode 100644 index 0000000..d564a73 --- /dev/null +++ b/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/clustering/JsonClusterWriter.java @@ -0,0 +1,188 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.mahout.utils.clustering; + +import java.io.IOException; +import java.io.Writer; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.regex.Pattern; + +import org.apache.mahout.clustering.AbstractCluster; +import org.apache.mahout.clustering.Cluster; +import org.apache.mahout.clustering.classify.WeightedPropertyVectorWritable; +import org.apache.mahout.clustering.iterator.ClusterWritable; +import org.apache.mahout.common.distance.DistanceMeasure; +import org.apache.mahout.math.NamedVector; +import org.apache.mahout.math.Vector; +import org.codehaus.jackson.map.ObjectMapper; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Dump cluster info to JSON formatted lines. 
Heavily inspired by + * ClusterDumperWriter.java and CSVClusterWriter.java + * + */ +public class JsonClusterWriter extends AbstractClusterWriter { + private final String[] dictionary; + private final int numTopFeatures; + private final ObjectMapper jxn; + + private static final Logger log = LoggerFactory.getLogger(JsonClusterWriter.class); + private static final Pattern VEC_PATTERN = Pattern.compile("\\{|\\:|\\,|\\}"); + + public JsonClusterWriter(Writer writer, + Map<Integer, List<WeightedPropertyVectorWritable>> clusterIdToPoints, + DistanceMeasure measure, int numTopFeatures, String[] dictionary) { + super(writer, clusterIdToPoints, measure); + this.numTopFeatures = numTopFeatures; + this.dictionary = dictionary; + jxn = new ObjectMapper(); + } + + /** + * Generate HashMap with cluster info and write as a single JSON formatted + * line + */ + @Override + public void write(ClusterWritable clusterWritable) throws IOException { + Map<String, Object> res = new HashMap<>(); + + // get top terms + if (dictionary != null) { + List<Object> topTerms = getTopFeaturesList(clusterWritable.getValue() + .getCenter(), dictionary, numTopFeatures); + res.put("top_terms", topTerms); + } else { + res.put("top_terms", new ArrayList<>()); + } + + // get human-readable cluster representation + Cluster cluster = clusterWritable.getValue(); + res.put("cluster_id", cluster.getId()); + + if (dictionary != null) { + Map<String,Object> fmtStr = cluster.asJson(dictionary); + res.put("cluster", fmtStr); + + // get points + List<Object> points = getPoints(cluster, dictionary); + res.put("points", points); + } else { + res.put("cluster", new HashMap<>()); + res.put("points", new ArrayList<>()); + } + + // write JSON + Writer writer = getWriter(); + writer.write(jxn.writeValueAsString(res) + "\n"); + } + + /** + * Create a List of HashMaps containing top terms information + * + * @return List<Object> + */ + public List<Object> getTopFeaturesList(Vector vector, String[] dictionary, + int numTerms) { + + List<TermIndexWeight> vectorTerms = new ArrayList<>(); + + for (Vector.Element elt : vector.nonZeroes()) { + vectorTerms.add(new TermIndexWeight(elt.index(), elt.get())); + } + + // Sort results in reverse order (i.e. 
weight in descending order) + Collections.sort(vectorTerms, new Comparator<TermIndexWeight>() { + @Override + public int compare(TermIndexWeight one, TermIndexWeight two) { + return Double.compare(two.weight, one.weight); + } + }); + + List<Object> topTerms = new ArrayList<>(); + + for (int i = 0; i < vectorTerms.size() && i < numTerms; i++) { + int index = vectorTerms.get(i).index; + String dictTerm = dictionary[index]; + if (dictTerm == null) { + log.error("Dictionary entry missing for {}", index); + continue; + } + Map<String, Object> term_entry = new HashMap<>(); + term_entry.put(dictTerm, vectorTerms.get(i).weight); + topTerms.add(term_entry); + } + + return topTerms; + } + + /** + * Create a List of HashMaps containing Vector point information + * + * @return List<Object> + */ + public List<Object> getPoints(Cluster cluster, String[] dictionary) { + List<Object> vectorObjs = new ArrayList<>(); + List<WeightedPropertyVectorWritable> points = getClusterIdToPoints().get( + cluster.getId()); + + if (points != null) { + for (WeightedPropertyVectorWritable point : points) { + Map<String, Object> entry = new HashMap<>(); + Vector theVec = point.getVector(); + if (theVec instanceof NamedVector) { + entry.put("vector_name", ((NamedVector) theVec).getName()); + } else { + String vecStr = theVec.asFormatString(); + // do some basic manipulations for display + vecStr = VEC_PATTERN.matcher(vecStr).replaceAll("_"); + entry.put("vector_name", vecStr); + } + entry.put("weight", String.valueOf(point.getWeight())); + try { + entry.put("point", + AbstractCluster.formatVectorAsJson(point.getVector(), dictionary)); + } catch (IOException e) { + log.error("IOException: ", e); + } + vectorObjs.add(entry); + } + } + return vectorObjs; + } + + /** + * Convenience class for sorting terms + * + */ + private static class TermIndexWeight { + private final int index; + private final double weight; + + TermIndexWeight(int index, double weight) { + this.index = index; + this.weight = weight; + } + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/e0573de3/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/email/MailOptions.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/email/MailOptions.java b/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/email/MailOptions.java new file mode 100644 index 0000000..54ad43f --- /dev/null +++ b/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/email/MailOptions.java @@ -0,0 +1,186 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.utils.email; + +import java.io.File; +import java.nio.charset.Charset; +import java.util.Map; +import java.util.regex.Pattern; + +/** + * Configuration options to be used by {@link MailProcessor}. Includes options controlling the exact output format + * and which mail fields are included (body, to, from, subject, etc.) + */ +public class MailOptions { + + public static final String FROM = "FROM"; + public static final String TO = "TO"; + public static final String REFS = "REFS"; + public static final String SUBJECT = "SUBJECT"; + public static final Pattern DEFAULT_QUOTED_TEXT = Pattern.compile("^(\\||>)"); + + private boolean stripQuotedText; + private File input; + private String outputDir; + private String prefix; + private int chunkSize; + private Charset charset; + private String separator; + private String bodySeparator = "\n"; + private boolean includeBody; + private Pattern[] patternsToMatch; + //maps FROM, TO, REFS, SUBJECT, etc. to the order they appear in patternsToMatch. See MailToRecMapper + private Map<String, Integer> patternOrder; + + //the regular expression to use for identifying quoted text. + private Pattern quotedTextPattern = DEFAULT_QUOTED_TEXT; + + public File getInput() { + return input; + } + + public void setInput(File input) { + this.input = input; + } + + public String getOutputDir() { + return outputDir; + } + + /** + * Sets the output directory where sequence files will be written. + */ + public void setOutputDir(String outputDir) { + this.outputDir = outputDir; + } + + public String getPrefix() { + return prefix; + } + + /** + * Sets the prefix that is combined with the archive name and with message ids to create {@code SequenceFile} keys. + * @param prefix The name of the directory containing the mail archive is commonly used. + */ + public void setPrefix(String prefix) { + this.prefix = prefix; + } + + public int getChunkSize() { + return chunkSize; + } + + /** + * Sets the size of each generated sequence file, in Megabytes. + */ + public void setChunkSize(int chunkSize) { + this.chunkSize = chunkSize; + } + + public Charset getCharset() { + return charset; + } + + /** + * Sets the encoding of the input + */ + public void setCharset(Charset charset) { + this.charset = charset; + } + + public String getSeparator() { + return separator; + } + + /** + * Sets the separator to use in the output between metadata items (to, from, etc.). + */ + public void setSeparator(String separator) { + this.separator = separator; + } + + public String getBodySeparator() { + return bodySeparator; + } + + /** + * Sets the separator to use in the output between lines in the body, the default is "\n". + */ + public void setBodySeparator(String bodySeparator) { + this.bodySeparator = bodySeparator; + } + + public boolean isIncludeBody() { + return includeBody; + } + + /** + * Sets whether mail bodies are included in the output + */ + public void setIncludeBody(boolean includeBody) { + this.includeBody = includeBody; + } + + public Pattern[] getPatternsToMatch() { + return patternsToMatch; + } + + /** + * Sets the list of patterns to be applied in the given order to extract metadata fields (to, from, subject, etc.) 
+ * from the input + */ + public void setPatternsToMatch(Pattern[] patternsToMatch) { + this.patternsToMatch = patternsToMatch; + } + + public Map<String, Integer> getPatternOrder() { + return patternOrder; + } + + public void setPatternOrder(Map<String, Integer> patternOrder) { + this.patternOrder = patternOrder; + } + + /** + * + * @return true if we should strip out quoted email text + */ + public boolean isStripQuotedText() { + return stripQuotedText; + } + + /** + * + * Sets whether quoted text such as lines starting with | or > is stripped off. + */ + public void setStripQuotedText(boolean stripQuotedText) { + this.stripQuotedText = stripQuotedText; + } + + public Pattern getQuotedTextPattern() { + return quotedTextPattern; + } + + /** + * Sets the {@link java.util.regex.Pattern} to use to identify lines that are quoted text. Default is | and >. + * @see #setStripQuotedText(boolean) + */ + public void setQuotedTextPattern(Pattern quotedTextPattern) { + this.quotedTextPattern = quotedTextPattern; + } +}
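For context (not part of this patch), the sketch below shows one plausible way a caller might populate MailOptions before handing it to MailProcessor. The input path, output directory, prefix, and header regexes are hypothetical placeholders, not values or patterns shipped with Mahout.

import java.io.File;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.Map;
import java.util.regex.Pattern;

import org.apache.mahout.utils.email.MailOptions;

public class MailOptionsExample {

  public static MailOptions buildOptions() {
    MailOptions options = new MailOptions();
    options.setInput(new File("/path/to/mail-archive"));   // hypothetical archive location
    options.setOutputDir("mail-chunks");                    // hypothetical output directory
    options.setPrefix("asf-mail");                          // commonly the archive directory name
    options.setChunkSize(64);                               // roughly 64 MB per generated SequenceFile
    options.setCharset(StandardCharsets.UTF_8);
    options.setSeparator("\n");
    options.setIncludeBody(true);
    options.setStripQuotedText(true);                       // drop lines matching the quoted-text pattern (| or >)

    // Illustrative header regexes; patternOrder maps each field name to the
    // index of its pattern within patternsToMatch.
    Pattern[] patterns = {
        Pattern.compile("^From: (.*)"),
        Pattern.compile("^Subject: (.*)")
    };
    options.setPatternsToMatch(patterns);
    Map<String, Integer> order = new HashMap<>();
    order.put(MailOptions.FROM, 0);
    order.put(MailOptions.SUBJECT, 1);
    options.setPatternOrder(order);

    return options;
  }
}

Nothing in this sketch overrides the defaults for bodySeparator or quotedTextPattern, which remain "\n" and DEFAULT_QUOTED_TEXT respectively.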

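As a closing illustration (again, not part of the patch itself), this sketch shows one way the ClusterWriter implementations added above might be driven. It assumes the static readPoints(...) helper shown earlier in this commit belongs to ClusterDumper; the "clusteredPoints" path, the output file name, the EuclideanDistanceMeasure, and the null dictionary are placeholder assumptions.

import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStreamWriter;
import java.io.Writer;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Map;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.mahout.clustering.classify.WeightedPropertyVectorWritable;
import org.apache.mahout.clustering.iterator.ClusterWritable;
import org.apache.mahout.common.distance.EuclideanDistanceMeasure;
import org.apache.mahout.utils.clustering.ClusterDumper;
import org.apache.mahout.utils.clustering.ClusterDumperWriter;
import org.apache.mahout.utils.clustering.ClusterWriter;

public class ClusterWriterExample {

  /** Dumps the given clusters in the Mahout 0.5 ClusterDumper text format. */
  public static void dump(Iterable<ClusterWritable> clusters) throws IOException {
    Configuration conf = new Configuration();

    // Map each cluster id to at most 100 of its member points, using the
    // readPoints helper from the diff above ("clusteredPoints" is a placeholder path).
    Map<Integer, List<WeightedPropertyVectorWritable>> points =
        ClusterDumper.readPoints(new Path("clusteredPoints"), 100, conf);

    try (Writer out = new OutputStreamWriter(new FileOutputStream("clusters.txt"), StandardCharsets.UTF_8);
         ClusterWriter writer = new ClusterDumperWriter(out, points,
             new EuclideanDistanceMeasure(), 10 /* top terms */, null /* no dictionary */, 0 /* no truncation */)) {
      writer.write(clusters);   // writes every cluster; the two-argument overload caps the count
    }
  }
}

Swapping in GraphMLClusterWriter or JsonClusterWriter only changes the constructor call; all of these writers honour the same ClusterWriter contract.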