http://git-wip-us.apache.org/repos/asf/mahout/blob/e0573de3/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/arff/ARFFVectorIterable.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/arff/ARFFVectorIterable.java b/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/arff/ARFFVectorIterable.java new file mode 100644 index 0000000..180a1e1 --- /dev/null +++ b/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/arff/ARFFVectorIterable.java @@ -0,0 +1,155 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.utils.vectors.arff; + +import java.io.BufferedReader; +import java.io.File; +import java.io.IOException; +import java.io.Reader; +import java.io.StringReader; +import java.nio.charset.Charset; +import java.text.DateFormat; +import java.text.SimpleDateFormat; +import java.util.Iterator; +import java.util.Locale; + +import com.google.common.io.Files; +import org.apache.commons.io.Charsets; +import org.apache.mahout.math.Vector; + +/** + * Read in ARFF (http://www.cs.waikato.ac.nz/~ml/weka/arff.html) and create {@link Vector}s + * <p/> + * Attribute type handling: + * <ul> + * <li>Numeric -> As is</li> + * <li>Nominal -> ordinal(value) i.e. @attribute lumber {'\'(-inf-0.5]\'','\'(0.5-inf)\''} + * will convert -inf-0.5 -> 0, and 0.5-inf -> 1</li> + * <li>Dates -> Convert to time as a long</li> + * <li>Strings -> Create a map of String -> long</li> + * </ul> + * NOTE: This class does not set the label bindings on every vector. If you want the label + * bindings, call {@link MapBackedARFFModel#getLabelBindings()}, as they are the same for every vector. 
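+ * <p>Typical usage, as a sketch (the file name here is illustrative):</p>
+ * <pre>
+ *   ARFFModel model = new MapBackedARFFModel();
+ *   ARFFVectorIterable vectors = new ARFFVectorIterable(new File("data.arff"), model);
+ *   for (Vector vector : vectors) {
+ *     // process vector; label bindings come from vectors.getModel().getLabelBindings()
+ *   }
+ * </pre>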
+ */ +public class ARFFVectorIterable implements Iterable<Vector> { + + private final BufferedReader buff; + private final ARFFModel model; + + public ARFFVectorIterable(File file, ARFFModel model) throws IOException { + this(file, Charsets.UTF_8, model); + } + + public ARFFVectorIterable(File file, Charset encoding, ARFFModel model) throws IOException { + this(Files.newReader(file, encoding), model); + } + + public ARFFVectorIterable(String arff, ARFFModel model) throws IOException { + this(new StringReader(arff), model); + } + + public ARFFVectorIterable(Reader reader, ARFFModel model) throws IOException { + if (reader instanceof BufferedReader) { + buff = (BufferedReader) reader; + } else { + buff = new BufferedReader(reader); + } + //grab the attributes, then start the iterator at the first line of data + this.model = model; + + int labelNumber = 0; + String line; + while ((line = buff.readLine()) != null) { + line = line.trim(); + if (!line.startsWith(ARFFModel.ARFF_COMMENT) && !line.isEmpty()) { + Integer labelNumInt = labelNumber; + String[] lineParts = line.split("[\\s\\t]+", 2); + + // is it a relation name? + if (lineParts[0].equalsIgnoreCase(ARFFModel.RELATION)) { + model.setRelation(ARFFType.removeQuotes(lineParts[1])); + } + // or an attribute + else if (lineParts[0].equalsIgnoreCase(ARFFModel.ATTRIBUTE)) { + String label; + ARFFType type; + + // split the name of the attribute and its description + String[] attrParts = lineParts[1].split("[\\s\\t]+", 2); + if (attrParts.length < 2) + throw new UnsupportedOperationException("No type for attribute found: " + lineParts[1]); + + // label is attribute name + label = ARFFType.removeQuotes(attrParts[0].toLowerCase()); + if (attrParts[1].equalsIgnoreCase(ARFFType.NUMERIC.getIndicator())) { + type = ARFFType.NUMERIC; + } else if (attrParts[1].equalsIgnoreCase(ARFFType.INTEGER.getIndicator())) { + type = ARFFType.INTEGER; + } else if (attrParts[1].equalsIgnoreCase(ARFFType.REAL.getIndicator())) { + type = ARFFType.REAL; + } else if (attrParts[1].equalsIgnoreCase(ARFFType.STRING.getIndicator())) { + type = ARFFType.STRING; + } else if (attrParts[1].toLowerCase().startsWith(ARFFType.NOMINAL.getIndicator())) { + type = ARFFType.NOMINAL; + // nominal example: + // @ATTRIBUTE class {Iris-setosa,'Iris versicolor',Iris-virginica} + String[] classes = ARFFIterator.splitCSV(attrParts[1].substring(1, attrParts[1].length() - 1)); + for (int i = 0; i < classes.length; i++) { + model.addNominal(label, ARFFType.removeQuotes(classes[i]), i + 1); + } + } else if (attrParts[1].toLowerCase().startsWith(ARFFType.DATE.getIndicator())) { + type = ARFFType.DATE; + //TODO: DateFormatter map + DateFormat format = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss", Locale.ENGLISH); + String formStr = attrParts[1].substring(ARFFType.DATE.getIndicator().length()).trim(); + if (!formStr.isEmpty()) { + if (formStr.startsWith("\"")) { + formStr = formStr.substring(1, formStr.length() - 1); + } + format = new SimpleDateFormat(formStr, Locale.ENGLISH); + } + model.addDateFormat(labelNumInt, format); + //@attribute <name> date [<date-format>] + } else { + throw new UnsupportedOperationException("Invalid attribute: " + attrParts[1]); + } + model.addLabel(label, labelNumInt); + model.addType(labelNumInt, type); + labelNumber++; + } else if (lineParts[0].equalsIgnoreCase(ARFFModel.DATA)) { + break; //skip it + } + } + } + + } + + @Override + public Iterator<Vector> iterator() { + return new ARFFIterator(buff, model); + } + + /** + * Returns info about the ARFF content that was 
parsed. + * + * @return the model + */ + public ARFFModel getModel() { + return model; + } +}
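The Driver that follows exposes this conversion on the command line. Judging from the options it declares, an invocation would look roughly like this (a sketch only: classpath setup is omitted and the paths are illustrative):

  java org.apache.mahout.utils.vectors.arff.Driver \
    --input /data/arff --output /out/vectors \
    --dictOut /out/dictionary.txt --max 1000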
http://git-wip-us.apache.org/repos/asf/mahout/blob/e0573de3/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/arff/Driver.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/arff/Driver.java b/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/arff/Driver.java new file mode 100644 index 0000000..ccecbb1 --- /dev/null +++ b/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/arff/Driver.java @@ -0,0 +1,263 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * <p/> + * http://www.apache.org/licenses/LICENSE-2.0 + * <p/> + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.utils.vectors.arff; + +import java.io.File; +import java.io.FilenameFilter; +import java.io.IOException; +import java.io.Writer; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashMap; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Set; + +import com.google.common.io.Files; +import org.apache.commons.cli2.CommandLine; +import org.apache.commons.cli2.Group; +import org.apache.commons.cli2.Option; +import org.apache.commons.cli2.OptionException; +import org.apache.commons.cli2.builder.ArgumentBuilder; +import org.apache.commons.cli2.builder.DefaultOptionBuilder; +import org.apache.commons.cli2.builder.GroupBuilder; +import org.apache.commons.cli2.commandline.Parser; +import org.apache.commons.io.Charsets; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.LongWritable; +import org.apache.hadoop.io.SequenceFile; +import org.apache.mahout.common.CommandLineUtil; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.VectorWritable; +import org.apache.mahout.utils.vectors.io.SequenceFileVectorWriter; +import org.apache.mahout.utils.vectors.io.VectorWriter; +import org.codehaus.jackson.map.ObjectMapper; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public final class Driver { + + private static final Logger log = LoggerFactory.getLogger(Driver.class); + + /** used for JSON serialization/deserialization */ + private static final ObjectMapper OBJECT_MAPPER = new ObjectMapper(); + + private Driver() { + } + + public static void main(String[] args) throws IOException { + DefaultOptionBuilder obuilder = new DefaultOptionBuilder(); + ArgumentBuilder abuilder = new ArgumentBuilder(); + GroupBuilder gbuilder = new GroupBuilder(); + + Option inputOpt = obuilder + .withLongName("input") + .withRequired(true) + .withArgument(abuilder.withName("input").withMinimum(1).withMaximum(1).create()) + .withDescription( 
+ "The file or directory containing the ARFF files. If it is a directory, all .arff files will be converted") + .withShortName("d").create(); + + Option outputOpt = obuilder.withLongName("output").withRequired(true).withArgument( + abuilder.withName("output").withMinimum(1).withMaximum(1).create()).withDescription( + "The output directory. Files will have the same name as the input, but with the extension .mvc") + .withShortName("o").create(); + + Option maxOpt = obuilder.withLongName("max").withRequired(false).withArgument( + abuilder.withName("max").withMinimum(1).withMaximum(1).create()).withDescription( + "The maximum number of vectors to output. If not specified, then it will loop over all docs") + .withShortName("m").create(); + + Option dictOutOpt = obuilder.withLongName("dictOut").withRequired(true).withArgument( + abuilder.withName("dictOut").withMinimum(1).withMaximum(1).create()).withDescription( + "The file to output the label bindings").withShortName("t").create(); + + Option jsonDictonaryOpt = obuilder.withLongName("json-dictonary").withRequired(false) + .withDescription("Write dictonary in JSON format").withShortName("j").create(); + + Option delimiterOpt = obuilder.withLongName("delimiter").withRequired(false).withArgument( + abuilder.withName("delimiter").withMinimum(1).withMaximum(1).create()).withDescription( + "The delimiter for outputing the dictionary").withShortName("l").create(); + + Option helpOpt = obuilder.withLongName("help").withDescription("Print out help").withShortName("h") + .create(); + Group group = gbuilder.withName("Options").withOption(inputOpt).withOption(outputOpt).withOption(maxOpt) + .withOption(helpOpt).withOption(dictOutOpt).withOption(jsonDictonaryOpt).withOption(delimiterOpt) + .create(); + + try { + Parser parser = new Parser(); + parser.setGroup(group); + CommandLine cmdLine = parser.parse(args); + + if (cmdLine.hasOption(helpOpt)) { + + CommandLineUtil.printHelp(group); + return; + } + if (cmdLine.hasOption(inputOpt)) { // Lucene case + File input = new File(cmdLine.getValue(inputOpt).toString()); + long maxDocs = Long.MAX_VALUE; + if (cmdLine.hasOption(maxOpt)) { + maxDocs = Long.parseLong(cmdLine.getValue(maxOpt).toString()); + } + if (maxDocs < 0) { + throw new IllegalArgumentException("maxDocs must be >= 0"); + } + String outDir = cmdLine.getValue(outputOpt).toString(); + log.info("Output Dir: {}", outDir); + + String delimiter = cmdLine.hasOption(delimiterOpt) ? 
cmdLine.getValue(delimiterOpt).toString() : "\t"; + File dictOut = new File(cmdLine.getValue(dictOutOpt).toString()); + boolean jsonDictonary = cmdLine.hasOption(jsonDictonaryOpt); + ARFFModel model = new MapBackedARFFModel(); + if (input.exists() && input.isDirectory()) { + File[] files = input.listFiles(new FilenameFilter() { + @Override + public boolean accept(File file, String name) { + return name.endsWith(".arff"); + } + }); + + for (File file : files) { + writeFile(outDir, file, maxDocs, model, dictOut, delimiter, jsonDictonary); + } + } else { + writeFile(outDir, input, maxDocs, model, dictOut, delimiter, jsonDictonary); + } + } + + } catch (OptionException e) { + log.error("Exception", e); + CommandLineUtil.printHelp(group); + } + } + + protected static void writeLabelBindings(File dictOut, ARFFModel arffModel, String delimiter, boolean jsonDictonary) + throws IOException { + try (Writer writer = Files.newWriterSupplier(dictOut, Charsets.UTF_8, true).getOutput()) { + if (jsonDictonary) { + writeLabelBindingsJSON(writer, arffModel); + } else { + writeLabelBindings(writer, arffModel, delimiter); + } + } + } + + protected static void writeLabelBindingsJSON(Writer writer, ARFFModel arffModel) throws IOException { + + // Turn the map of labels into a list order by order of appearance + List<Entry<String, Integer>> attributes = new ArrayList<>(); + attributes.addAll(arffModel.getLabelBindings().entrySet()); + Collections.sort(attributes, new Comparator<Map.Entry<String, Integer>>() { + @Override + public int compare(Entry<String, Integer> t, Entry<String, Integer> t1) { + return t.getValue().compareTo(t1.getValue()); + } + }); + + // write a map for each object + List<Map<String, Object>> jsonObjects = new LinkedList<>(); + for (int i = 0; i < attributes.size(); i++) { + + Entry<String, Integer> modelRepresentation = attributes.get(i); + Map<String, Object> jsonRepresentation = new HashMap<>(); + jsonObjects.add(jsonRepresentation); + // the last one is the class label + jsonRepresentation.put("label", i < (attributes.size() - 1) ? 
String.valueOf(false) : String.valueOf(true)); + String attribute = modelRepresentation.getKey(); + jsonRepresentation.put("attribute", attribute); + Map<String, Integer> nominalValues = arffModel.getNominalMap().get(attribute); + + if (nominalValues != null) { + String[] values = nominalValues.keySet().toArray(new String[1]); + + jsonRepresentation.put("values", values); + jsonRepresentation.put("type", "categorical"); + } else { + jsonRepresentation.put("type", "numerical"); + } + } + writer.write(OBJECT_MAPPER.writeValueAsString(jsonObjects)); + } + + protected static void writeLabelBindings(Writer writer, ARFFModel arffModel, String delimiter) throws IOException { + + Map<String, Integer> labels = arffModel.getLabelBindings(); + writer.write("Label bindings for Relation " + arffModel.getRelation() + '\n'); + for (Map.Entry<String, Integer> entry : labels.entrySet()) { + writer.write(entry.getKey()); + writer.write(delimiter); + writer.write(String.valueOf(entry.getValue())); + writer.write('\n'); + } + writer.write('\n'); + writer.write("Values for nominal attributes\n"); + // emit allowed values for NOMINAL/categorical/enumerated attributes + Map<String, Map<String, Integer>> nominalMap = arffModel.getNominalMap(); + // how many nominal attributes + writer.write(String.valueOf(nominalMap.size()) + "\n"); + + for (Entry<String, Map<String, Integer>> entry : nominalMap.entrySet()) { + // the label of this attribute + writer.write(entry.getKey() + "\n"); + Set<Entry<String, Integer>> attributeValues = entry.getValue().entrySet(); + // how many values does this attribute have + writer.write(attributeValues.size() + "\n"); + for (Map.Entry<String, Integer> value : attributeValues) { + // the value and the value index + writer.write(String.format("%s%s%s\n", value.getKey(), delimiter, value.getValue().toString())); + } + } + } + + protected static void writeFile(String outDir, + File file, + long maxDocs, + ARFFModel arffModel, + File dictOut, + String delimiter, + boolean jsonDictonary) throws IOException { + log.info("Converting File: {}", file); + ARFFModel model = new MapBackedARFFModel(arffModel.getWords(), arffModel.getWordCount() + 1, arffModel + .getNominalMap()); + Iterable<Vector> iteratable = new ARFFVectorIterable(file, model); + String outFile = outDir + '/' + file.getName() + ".mvc"; + + try (VectorWriter vectorWriter = getSeqFileWriter(outFile)) { + long numDocs = vectorWriter.write(iteratable, maxDocs); + writeLabelBindings(dictOut, model, delimiter, jsonDictonary); + log.info("Wrote: {} vectors", numDocs); + } + } + + private static VectorWriter getSeqFileWriter(String outFile) throws IOException { + Path path = new Path(outFile); + Configuration conf = new Configuration(); + FileSystem fs = FileSystem.get(conf); + SequenceFile.Writer seqWriter = SequenceFile.createWriter(fs, conf, path, LongWritable.class, + VectorWritable.class); + return new SequenceFileVectorWriter(seqWriter); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/e0573de3/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/arff/MapBackedARFFModel.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/arff/MapBackedARFFModel.java b/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/arff/MapBackedARFFModel.java new file mode 100644 index 0000000..e911b1a --- /dev/null +++ 
b/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/arff/MapBackedARFFModel.java @@ -0,0 +1,282 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.utils.vectors.arff; + +import java.text.DateFormat; +import java.text.NumberFormat; +import java.text.ParseException; +import java.text.ParsePosition; +import java.text.SimpleDateFormat; +import java.util.Collections; +import java.util.Date; +import java.util.HashMap; +import java.util.Locale; +import java.util.Map; +import java.util.regex.Pattern; + +/** + * Holds ARFF information in {@link Map}. + */ +public class MapBackedARFFModel implements ARFFModel { + + private static final Pattern QUOTE_PATTERN = Pattern.compile("\""); + + private long wordCount = 1; + + private String relation; + + private final Map<String,Integer> labelBindings; + private final Map<Integer,String> idxLabel; + private final Map<Integer,ARFFType> typeMap; // key is the vector index, value is the type + private final Map<Integer,DateFormat> dateMap; + private final Map<String,Map<String,Integer>> nominalMap; + private final Map<String,Long> words; + + public MapBackedARFFModel() { + this(new HashMap<String,Long>(), 1, new HashMap<String,Map<String,Integer>>()); + } + + public MapBackedARFFModel(Map<String,Long> words, long wordCount, Map<String,Map<String,Integer>> nominalMap) { + this.words = words; + this.wordCount = wordCount; + labelBindings = new HashMap<>(); + idxLabel = new HashMap<>(); + typeMap = new HashMap<>(); + dateMap = new HashMap<>(); + this.nominalMap = nominalMap; + + } + + @Override + public String getRelation() { + return relation; + } + + @Override + public void setRelation(String relation) { + this.relation = relation; + } + + /** + * Convert a piece of String data at a specific spot into a value + * + * @param data + * The data to convert + * @param idx + * The position in the ARFF data + * @return A double representing the data + */ + @Override + public double getValue(String data, int idx) { + ARFFType type = typeMap.get(idx); + if (type == null) { + throw new IllegalArgumentException("Attribute type cannot be NULL, attribute index was: " + idx); + } + data = QUOTE_PATTERN.matcher(data).replaceAll(""); + data = data.trim(); + double result; + switch (type) { + case NUMERIC: + case INTEGER: + case REAL: + result = processNumeric(data); + break; + case DATE: + result = processDate(data, idx); + break; + case STRING: + // may have quotes + result = processString(data); + break; + case NOMINAL: + String label = idxLabel.get(idx); + result = processNominal(label, data); + break; + default: + throw new IllegalStateException("Unknown type: " + type); + } + return result; + } + + protected double processNominal(String label, String data) { + double 
result;
+    Map<String,Integer> classes = nominalMap.get(label);
+    if (classes != null) {
+      Integer ord = classes.get(ARFFType.removeQuotes(data));
+      if (ord != null) {
+        result = ord;
+      } else {
+        throw new IllegalStateException("Invalid nominal: " + data + " for label: " + label);
+      }
+    } else {
+      throw new IllegalArgumentException("Invalid nominal label: " + label + " Data: " + data);
+    }
+
+    return result;
+  }
+
+  // Not sure how scalable this is going to be
+  protected double processString(String data) {
+    data = QUOTE_PATTERN.matcher(data).replaceAll("");
+    // map it to a long
+    Long theLong = words.get(data);
+    if (theLong == null) {
+      theLong = wordCount++;
+      words.put(data, theLong);
+    }
+    return theLong;
+  }
+
+  protected static double processNumeric(String data) {
+    if (isNumeric(data)) {
+      return Double.parseDouble(data);
+    }
+    return Double.NaN;
+  }
+
+  public static boolean isNumeric(String str) {
+    NumberFormat formatter = NumberFormat.getInstance();
+    ParsePosition parsePosition = new ParsePosition(0);
+    formatter.parse(str, parsePosition);
+    return str.length() == parsePosition.getIndex();
+  }
+
+  protected double processDate(String data, int idx) {
+    DateFormat format = dateMap.get(idx);
+    if (format == null) {
+      format = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss", Locale.ENGLISH);
+    }
+    double result;
+    try {
+      Date date = format.parse(data);
+      result = date.getTime(); // the long timestamp is represented as a double
+    } catch (ParseException e) {
+      throw new IllegalArgumentException(e);
+    }
+    return result;
+  }
+
+  /**
+   * The vector attributes (labels in Mahout speak), unmodifiable
+   *
+   * @return the map
+   */
+  @Override
+  public Map<String,Integer> getLabelBindings() {
+    return Collections.unmodifiableMap(labelBindings);
+  }
+
+  /**
+   * The map of types encountered
+   *
+   * @return the map
+   */
+  public Map<Integer,ARFFType> getTypeMap() {
+    return Collections.unmodifiableMap(typeMap);
+  }
+
+  /**
+   * Map of Date formatters used
+   *
+   * @return the map
+   */
+  public Map<Integer,DateFormat> getDateMap() {
+    return Collections.unmodifiableMap(dateMap);
+  }
+
+  /**
+   * Map nominals to ids. 
Should only be modified by calling {@link ARFFModel#addNominal(String, String, int)} + * + * @return the map + */ + @Override + public Map<String,Map<String,Integer>> getNominalMap() { + return nominalMap; + } + + /** + * Immutable map of words to the long id used for those words + * + * @return The map + */ + @Override + public Map<String,Long> getWords() { + return words; + } + + @Override + public Integer getNominalValue(String label, String nominal) { + return nominalMap.get(label).get(nominal); + } + + @Override + public void addNominal(String label, String nominal, int idx) { + Map<String,Integer> noms = nominalMap.get(label); + if (noms == null) { + noms = new HashMap<>(); + nominalMap.put(label, noms); + } + noms.put(nominal, idx); + } + + @Override + public DateFormat getDateFormat(Integer idx) { + return dateMap.get(idx); + } + + @Override + public void addDateFormat(Integer idx, DateFormat format) { + dateMap.put(idx, format); + } + + @Override + public Integer getLabelIndex(String label) { + return labelBindings.get(label); + } + + @Override + public void addLabel(String label, Integer idx) { + labelBindings.put(label, idx); + idxLabel.put(idx, label); + } + + @Override + public ARFFType getARFFType(Integer idx) { + return typeMap.get(idx); + } + + @Override + public void addType(Integer idx, ARFFType type) { + typeMap.put(idx, type); + } + + /** + * The count of the number of words seen + * + * @return the count + */ + @Override + public long getWordCount() { + return wordCount; + } + + @Override + public int getLabelSize() { + return labelBindings.size(); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/e0573de3/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/csv/CSVVectorIterator.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/csv/CSVVectorIterator.java b/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/csv/CSVVectorIterator.java new file mode 100644 index 0000000..3c583fd --- /dev/null +++ b/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/csv/CSVVectorIterator.java @@ -0,0 +1,69 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.utils.vectors.csv; + +import java.io.IOException; +import java.io.Reader; + +import com.google.common.collect.AbstractIterator; +import org.apache.commons.csv.CSVParser; +import org.apache.commons.csv.CSVStrategy; +import org.apache.mahout.math.DenseVector; +import org.apache.mahout.math.Vector; + +/** + * Iterates a CSV file and produces {@link org.apache.mahout.math.Vector}. 
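+ * <p>Usage sketch (the inline data here is illustrative):</p>
+ * <pre>
+ *   CSVVectorIterator it = new CSVVectorIterator(new StringReader("1.0,2.0\n3.0,4.0"));
+ *   while (it.hasNext()) {
+ *     Vector v = it.next(); // a DenseVector with one entry per column
+ *   }
+ * </pre>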
+ * <br/> + * The Iterator returned throws {@link UnsupportedOperationException} for the {@link java.util.Iterator#remove()} + * method. + * <p/> + * Assumes DenseVector for now, but in the future may have the option of mapping columns to sparse format + * <p/> + * The Iterator is not thread-safe. + */ +public class CSVVectorIterator extends AbstractIterator<Vector> { + + private final CSVParser parser; + + public CSVVectorIterator(Reader reader) { + parser = new CSVParser(reader); + } + + public CSVVectorIterator(Reader reader, CSVStrategy strategy) { + parser = new CSVParser(reader, strategy); + } + + @Override + protected Vector computeNext() { + String[] line; + try { + line = parser.getLine(); + } catch (IOException e) { + throw new IllegalStateException(e); + } + if (line == null) { + return endOfData(); + } + Vector result = new DenseVector(line.length); + for (int i = 0; i < line.length; i++) { + result.setQuick(i, Double.parseDouble(line[i])); + } + return result; + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/e0573de3/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/io/DelimitedTermInfoWriter.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/io/DelimitedTermInfoWriter.java b/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/io/DelimitedTermInfoWriter.java new file mode 100644 index 0000000..b5f9f2b --- /dev/null +++ b/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/io/DelimitedTermInfoWriter.java @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.utils.vectors.io; + +import java.io.IOException; +import java.io.Writer; +import java.util.Iterator; + +import com.google.common.io.Closeables; +import org.apache.mahout.utils.vectors.TermEntry; +import org.apache.mahout.utils.vectors.TermInfo; + +/** + * Write {@link TermInfo} to a {@link Writer} in a textual, delimited format with header. 
+ */ +public class DelimitedTermInfoWriter implements TermInfoWriter { + + private final Writer writer; + private final String delimiter; + private final String field; + + public DelimitedTermInfoWriter(Writer writer, String delimiter, String field) { + this.writer = writer; + this.delimiter = delimiter; + this.field = field; + } + + @Override + public void write(TermInfo ti) throws IOException { + + Iterator<TermEntry> entIter = ti.getAllEntries(); + try { + writer.write(String.valueOf(ti.totalTerms(field))); + writer.write('\n'); + writer.write("#term" + delimiter + "doc freq" + delimiter + "idx"); + writer.write('\n'); + while (entIter.hasNext()) { + TermEntry entry = entIter.next(); + writer.write(entry.getTerm()); + writer.write(delimiter); + writer.write(String.valueOf(entry.getDocFreq())); + writer.write(delimiter); + writer.write(String.valueOf(entry.getTermIdx())); + writer.write('\n'); + } + } finally { + Closeables.close(writer, false); + } + } + + /** + * Does NOT close the underlying writer + */ + @Override + public void close() { + + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/e0573de3/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/io/SequenceFileVectorWriter.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/io/SequenceFileVectorWriter.java b/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/io/SequenceFileVectorWriter.java new file mode 100644 index 0000000..0d763a1 --- /dev/null +++ b/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/io/SequenceFileVectorWriter.java @@ -0,0 +1,75 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.utils.vectors.io; + +import java.io.IOException; + +import com.google.common.io.Closeables; +import org.apache.hadoop.io.LongWritable; +import org.apache.hadoop.io.SequenceFile; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.VectorWritable; + + +/** + * Writes out Vectors to a SequenceFile. 
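+ * <p>Sketch of typical use (the writer setup is abbreviated; Driver.getSeqFileWriter in this commit
+ * builds one the same way):</p>
+ * <pre>
+ *   SequenceFile.Writer w = SequenceFile.createWriter(fs, conf, path, LongWritable.class, VectorWritable.class);
+ *   try (VectorWriter vectorWriter = new SequenceFileVectorWriter(w)) {
+ *     vectorWriter.write(vectors, 1000); // writes at most 1000 vectors
+ *   }
+ * </pre>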
+ * + * Closes the writer when done + */ +public class SequenceFileVectorWriter implements VectorWriter { + private final SequenceFile.Writer writer; + private long recNum = 0; + public SequenceFileVectorWriter(SequenceFile.Writer writer) { + this.writer = writer; + } + + @Override + public long write(Iterable<Vector> iterable, long maxDocs) throws IOException { + + for (Vector point : iterable) { + if (recNum >= maxDocs) { + break; + } + if (point != null) { + writer.append(new LongWritable(recNum++), new VectorWritable(point)); + } + + } + return recNum; + } + + @Override + public void write(Vector vector) throws IOException { + writer.append(new LongWritable(recNum++), new VectorWritable(vector)); + + } + + @Override + public long write(Iterable<Vector> iterable) throws IOException { + return write(iterable, Long.MAX_VALUE); + } + + @Override + public void close() throws IOException { + Closeables.close(writer, false); + } + + public SequenceFile.Writer getWriter() { + return writer; + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/e0573de3/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/io/TermInfoWriter.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/io/TermInfoWriter.java b/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/io/TermInfoWriter.java new file mode 100644 index 0000000..e165b45 --- /dev/null +++ b/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/io/TermInfoWriter.java @@ -0,0 +1,29 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.utils.vectors.io; + +import java.io.Closeable; +import java.io.IOException; + +import org.apache.mahout.utils.vectors.TermInfo; + +public interface TermInfoWriter extends Closeable { + + void write(TermInfo ti) throws IOException; + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/e0573de3/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/io/TextualVectorWriter.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/io/TextualVectorWriter.java b/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/io/TextualVectorWriter.java new file mode 100644 index 0000000..cc27d1d --- /dev/null +++ b/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/io/TextualVectorWriter.java @@ -0,0 +1,70 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.utils.vectors.io; + +import java.io.IOException; +import java.io.Writer; + +import com.google.common.io.Closeables; +import org.apache.mahout.math.Vector; + +/** + * Write out the vectors to any {@link Writer} using {@link Vector#asFormatString()}, + * one per line by default. + */ +public class TextualVectorWriter implements VectorWriter { + + private final Writer writer; + + public TextualVectorWriter(Writer writer) { + this.writer = writer; + } + + protected Writer getWriter() { + return writer; + } + + @Override + public long write(Iterable<Vector> iterable) throws IOException { + return write(iterable, Long.MAX_VALUE); + } + + @Override + public long write(Iterable<Vector> iterable, long maxDocs) throws IOException { + long result = 0; + for (Vector vector : iterable) { + if (result >= maxDocs) { + break; + } + write(vector); + result++; + } + return result; + } + + @Override + public void write(Vector vector) throws IOException { + writer.write(vector.asFormatString()); + writer.write('\n'); + } + + @Override + public void close() throws IOException { + Closeables.close(writer, false); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/e0573de3/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/io/VectorWriter.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/io/VectorWriter.java b/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/io/VectorWriter.java new file mode 100644 index 0000000..923e270 --- /dev/null +++ b/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/io/VectorWriter.java @@ -0,0 +1,52 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.utils.vectors.io; + +import java.io.Closeable; +import java.io.IOException; + +import org.apache.mahout.math.Vector; + +public interface VectorWriter extends Closeable { + /** + * Write all values in the Iterable to the output + * @param iterable The {@link Iterable} to loop over + * @return the number of docs written + * @throws IOException if there was a problem writing + * + */ + long write(Iterable<Vector> iterable) throws IOException; + + /** + * Write out a vector + * + * @param vector The {@link org.apache.mahout.math.Vector} to write + * @throws IOException + */ + void write(Vector vector) throws IOException; + + /** + * Write the first {@code maxDocs} to the output. + * @param iterable The {@link Iterable} to loop over + * @param maxDocs the maximum number of docs to write + * @return The number of docs written + * @throws IOException if there was a problem writing + */ + long write(Iterable<Vector> iterable, long maxDocs) throws IOException; + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/e0573de3/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/lucene/AbstractLuceneIterator.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/lucene/AbstractLuceneIterator.java b/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/lucene/AbstractLuceneIterator.java new file mode 100644 index 0000000..ff61a70 --- /dev/null +++ b/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/lucene/AbstractLuceneIterator.java @@ -0,0 +1,140 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.utils.vectors.lucene; + +import com.google.common.collect.AbstractIterator; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.Terms; +import org.apache.lucene.index.TermsEnum; +import org.apache.lucene.util.BytesRef; +import org.apache.mahout.math.NamedVector; +import org.apache.mahout.math.Vector; +import org.apache.mahout.utils.Bump125; +import org.apache.mahout.utils.vectors.TermInfo; +import org.apache.mahout.vectorizer.Weight; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; + +/** + * Iterate over a Lucene index, extracting term vectors. + * Subclasses define how much information to retrieve from the Lucene index. 
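+ * <p>A minimal subclass only needs to supply {@code getVectorName}; the sketch below simply uses
+ * the Lucene-internal document id as the vector name:</p>
+ * <pre>
+ *   protected String getVectorName(int documentIndex) {
+ *     return String.valueOf(documentIndex);
+ *   }
+ * </pre>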
+ */ +public abstract class AbstractLuceneIterator extends AbstractIterator<Vector> { + private static final Logger log = LoggerFactory.getLogger(LuceneIterator.class); + protected final IndexReader indexReader; + protected final String field; + protected final TermInfo terminfo; + protected final double normPower; + protected final Weight weight; + protected final Bump125 bump = new Bump125(); + protected int nextDocId; + protected int maxErrorDocs; + protected int numErrorDocs; + protected long nextLogRecord = bump.increment(); + protected int skippedErrorMessages; + + public AbstractLuceneIterator(TermInfo terminfo, double normPower, IndexReader indexReader, Weight weight, + double maxPercentErrorDocs, String field) { + this.terminfo = terminfo; + this.normPower = normPower; + this.indexReader = indexReader; + + this.weight = weight; + this.nextDocId = 0; + this.maxErrorDocs = (int) (maxPercentErrorDocs * indexReader.numDocs()); + this.field = field; + } + + /** + * Given the document name, derive a name for the vector. This may involve + * reading the document from Lucene and setting up any other state that the + * subclass wants. This will be called once for each document that the + * iterator processes. + * @param documentIndex the lucene document index. + * @return the name to store in the vector. + */ + protected abstract String getVectorName(int documentIndex) throws IOException; + + @Override + protected Vector computeNext() { + try { + int doc; + Terms termFreqVector; + String name; + + do { + doc = this.nextDocId; + nextDocId++; + + if (doc >= indexReader.maxDoc()) { + return endOfData(); + } + + termFreqVector = indexReader.getTermVector(doc, field); + name = getVectorName(doc); + + if (termFreqVector == null) { + numErrorDocs++; + if (numErrorDocs >= maxErrorDocs) { + log.error("There are too many documents that do not have a term vector for {}", field); + throw new IllegalStateException("There are too many documents that do not have a term vector for " + + field); + } + if (numErrorDocs >= nextLogRecord) { + if (skippedErrorMessages == 0) { + log.warn("{} does not have a term vector for {}", name, field); + } else { + log.warn("{} documents do not have a term vector for {}", numErrorDocs, field); + } + nextLogRecord = bump.increment(); + skippedErrorMessages = 0; + } else { + skippedErrorMessages++; + } + } + } while (termFreqVector == null); + + // The loop exits with termFreqVector and name set. + + TermsEnum te = termFreqVector.iterator(); + BytesRef term; + TFDFMapper mapper = new TFDFMapper(indexReader.numDocs(), weight, this.terminfo); + mapper.setExpectations(field, termFreqVector.size()); + while ((term = te.next()) != null) { + mapper.map(term, (int) te.totalTermFreq()); + } + Vector result = mapper.getVector(); + if (result == null) { + // TODO is this right? 
last version would produce null in the iteration in this case, though it + // seems like that may not be desirable + return null; + } + + if (normPower == LuceneIterable.NO_NORMALIZING) { + result = new NamedVector(result, name); + } else { + result = new NamedVector(result.normalize(normPower), name); + } + return result; + } catch (IOException ioe) { + throw new IllegalStateException(ioe); + } + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/e0573de3/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/lucene/CachedTermInfo.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/lucene/CachedTermInfo.java b/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/lucene/CachedTermInfo.java new file mode 100644 index 0000000..0b59ed6 --- /dev/null +++ b/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/lucene/CachedTermInfo.java @@ -0,0 +1,79 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.utils.vectors.lucene; + +import java.io.IOException; +import java.util.Iterator; +import java.util.LinkedHashMap; +import java.util.Map; + +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.MultiFields; +import org.apache.lucene.index.Terms; +import org.apache.lucene.index.TermsEnum; +import org.apache.lucene.util.BytesRef; +import org.apache.mahout.utils.vectors.TermEntry; +import org.apache.mahout.utils.vectors.TermInfo; + + +/** + * Caches TermEntries from a single field. Materializes all values in the TermEnum to memory (much like FieldCache) + */ +public class CachedTermInfo implements TermInfo { + + private final Map<String, TermEntry> termEntries; + private final String field; + + public CachedTermInfo(IndexReader reader, String field, int minDf, int maxDfPercent) throws IOException { + this.field = field; + Terms t = MultiFields.getTerms(reader, field); + TermsEnum te = t.iterator(); + + int numDocs = reader.numDocs(); + double percent = numDocs * maxDfPercent / 100.0; + //Should we use a linked hash map so that we know terms are in order? 
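+    // A LinkedHashMap is used, so getAllEntries() returns terms in the order the TermsEnum produced them.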
+ termEntries = new LinkedHashMap<>(); + int count = 0; + BytesRef text; + while ((text = te.next()) != null) { + int df = te.docFreq(); + if (df >= minDf && df <= percent) { + TermEntry entry = new TermEntry(text.utf8ToString(), count++, df); + termEntries.put(entry.getTerm(), entry); + } + } + } + + @Override + public int totalTerms(String field) { + return termEntries.size(); + } + + @Override + public TermEntry getTermEntry(String field, String term) { + if (!this.field.equals(field)) { + return null; + } + return termEntries.get(term); + } + + @Override + public Iterator<TermEntry> getAllEntries() { + return termEntries.values().iterator(); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/e0573de3/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/lucene/ClusterLabels.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/lucene/ClusterLabels.java b/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/lucene/ClusterLabels.java new file mode 100644 index 0000000..b2568e7 --- /dev/null +++ b/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/lucene/ClusterLabels.java @@ -0,0 +1,381 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.utils.vectors.lucene; + +import java.io.File; +import java.io.IOException; +import java.io.OutputStreamWriter; +import java.io.Writer; +import java.nio.file.Paths; +import java.util.Collection; +import java.util.Collections; +import java.util.HashSet; +import java.util.LinkedHashMap; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.TreeSet; + +import com.google.common.io.Closeables; +import com.google.common.io.Files; +import org.apache.commons.cli2.CommandLine; +import org.apache.commons.cli2.Group; +import org.apache.commons.cli2.Option; +import org.apache.commons.cli2.OptionException; +import org.apache.commons.cli2.builder.ArgumentBuilder; +import org.apache.commons.cli2.builder.DefaultOptionBuilder; +import org.apache.commons.cli2.builder.GroupBuilder; +import org.apache.commons.cli2.commandline.Parser; +import org.apache.commons.io.Charsets; +import org.apache.hadoop.fs.Path; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.index.MultiFields; +import org.apache.lucene.index.PostingsEnum; +import org.apache.lucene.index.Term; +import org.apache.lucene.index.Terms; +import org.apache.lucene.index.TermsEnum; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.store.Directory; +import org.apache.lucene.store.FSDirectory; +import org.apache.lucene.util.Bits; +import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.FixedBitSet; +import org.apache.mahout.clustering.classify.WeightedPropertyVectorWritable; +import org.apache.mahout.common.CommandLineUtil; +import org.apache.mahout.common.commandline.DefaultOptionCreator; +import org.apache.mahout.math.NamedVector; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.stats.LogLikelihood; +import org.apache.mahout.utils.clustering.ClusterDumper; +import org.apache.mahout.utils.vectors.TermEntry; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Get labels for the cluster using Log Likelihood Ratio (LLR). + * <p/> + *"The most useful way to think of this (LLR) is as the percentage of in-cluster documents that have the + * feature (term) versus the percentage out, keeping in mind that both percentages are uncertain since we have + * only a sample of all possible documents." - Ted Dunning + * <p/> + * More about LLR can be found at : http://tdunning.blogspot.com/2008/03/surprise-and-coincidence.html + */ +public class ClusterLabels { + + private static final Logger log = LoggerFactory.getLogger(ClusterLabels.class); + + public static final int DEFAULT_MIN_IDS = 50; + public static final int DEFAULT_MAX_LABELS = 25; + + private final String indexDir; + private final String contentField; + private String idField; + private final Map<Integer, List<WeightedPropertyVectorWritable>> clusterIdToPoints; + private String output; + private final int minNumIds; + private final int maxLabels; + + public ClusterLabels(Path seqFileDir, + Path pointsDir, + String indexDir, + String contentField, + int minNumIds, + int maxLabels) { + this.indexDir = indexDir; + this.contentField = contentField; + this.minNumIds = minNumIds; + this.maxLabels = maxLabels; + ClusterDumper clusterDumper = new ClusterDumper(seqFileDir, pointsDir); + this.clusterIdToPoints = clusterDumper.getClusterIdToPoints(); + } + + public void getLabels() throws IOException { + + try (Writer writer = (this.output == null) ? 
+        new OutputStreamWriter(System.out, Charsets.UTF_8) : Files.newWriter(new File(this.output), Charsets.UTF_8)) {
+      for (Map.Entry<Integer, List<WeightedPropertyVectorWritable>> integerListEntry : clusterIdToPoints.entrySet()) {
+        List<WeightedPropertyVectorWritable> wpvws = integerListEntry.getValue();
+        List<TermInfoClusterInOut> termInfos = getClusterLabels(integerListEntry.getKey(), wpvws);
+        if (termInfos != null) {
+          writer.write('\n');
+          writer.write("Top labels for Cluster ");
+          writer.write(String.valueOf(integerListEntry.getKey()));
+          writer.write(" containing ");
+          writer.write(String.valueOf(wpvws.size()));
+          writer.write(" vectors");
+          writer.write('\n');
+          writer.write("Term \t\t LLR \t\t In-ClusterDF \t\t Out-ClusterDF ");
+          writer.write('\n');
+          for (TermInfoClusterInOut termInfo : termInfos) {
+            writer.write(termInfo.getTerm());
+            writer.write("\t\t");
+            writer.write(String.valueOf(termInfo.getLogLikelihoodRatio()));
+            writer.write("\t\t");
+            writer.write(String.valueOf(termInfo.getInClusterDF()));
+            writer.write("\t\t");
+            writer.write(String.valueOf(termInfo.getOutClusterDF()));
+            writer.write('\n');
+          }
+        }
+      }
+    }
+  }
+
+  /**
+   * Get the list of labels, sorted by best score.
+   */
+  protected List<TermInfoClusterInOut> getClusterLabels(Integer integer,
+      Collection<WeightedPropertyVectorWritable> wpvws) throws IOException {
+
+    if (wpvws.size() < minNumIds) {
+      log.info("Skipping small cluster {} with size: {}", integer, wpvws.size());
+      return null;
+    }
+
+    log.info("Processing Cluster {} with {} documents", integer, wpvws.size());
+    Directory dir = FSDirectory.open(Paths.get(this.indexDir));
+    IndexReader reader = DirectoryReader.open(dir);
+
+    log.info("# of documents in the index {}", reader.numDocs());
+
+    Collection<String> idSet = new HashSet<>();
+    for (WeightedPropertyVectorWritable wpvw : wpvws) {
+      Vector vector = wpvw.getVector();
+      if (vector instanceof NamedVector) {
+        idSet.add(((NamedVector) vector).getName());
+      }
+    }
+
+    int numDocs = reader.numDocs();
+
+    FixedBitSet clusterDocBitset = getClusterDocBitset(reader, idSet, this.idField);
+
+    log.info("Populating term infos from the index");
+
+    /*
+     * This code is similar to that of CachedTermInfo, with one major change: how the document frequency
+     * is obtained.
+     *
+     * Since we only want in-cluster counts, the document frequency for a term should only
+     * include the in-cluster documents. The document frequency obtained from TermEnum reflects the frequency
+     * in the entire index. To get the in-cluster frequency, we need to query the index to get the term
+     * frequencies in each document. The number of results of this call will be the in-cluster document
+     * frequency.
+     */
+    Terms t = MultiFields.getTerms(reader, contentField);
+    TermsEnum te = t.iterator();
+    Map<String, TermEntry> termEntryMap = new LinkedHashMap<>();
+    Bits liveDocs = MultiFields.getLiveDocs(reader); //WARNING: returns null if there are no deletions
+
+    int count = 0;
+    BytesRef term;
+    while ((term = te.next()) != null) {
+      FixedBitSet termBitset = new FixedBitSet(reader.maxDoc());
+      PostingsEnum docsEnum = MultiFields.getTermDocsEnum(reader, contentField, term);
+      int docID;
+      while ((docID = docsEnum.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
+        // count the document if there are no deletions (liveDocs == null) or if it is live
+        if (liveDocs == null || liveDocs.get(docID)) {
+          // document is live, so record it for this term
+ termBitset.set(docID); + } + } + // AND the term's bitset with the cluster doc bitset to get the term's in-cluster frequency. + // This modifies termBitset, but that's fine as we are not using it anywhere else. + termBitset.and(clusterDocBitset); + int inClusterDF = (int) termBitset.cardinality(); + + TermEntry entry = new TermEntry(term.utf8ToString(), count++, inClusterDF); + termEntryMap.put(entry.getTerm(), entry); + } + + List<TermInfoClusterInOut> clusteredTermInfo = new LinkedList<>(); + + int clusterSize = wpvws.size(); + + for (TermEntry termEntry : termEntryMap.values()) { + int corpusDF = reader.docFreq(new Term(this.contentField, termEntry.getTerm())); + int outDF = corpusDF - termEntry.getDocFreq(); + int inDF = termEntry.getDocFreq(); + double logLikelihoodRatio = scoreDocumentFrequencies(inDF, outDF, clusterSize, numDocs); + TermInfoClusterInOut termInfoCluster = + new TermInfoClusterInOut(termEntry.getTerm(), inDF, outDF, logLikelihoodRatio); + clusteredTermInfo.add(termInfoCluster); + } + + Collections.sort(clusteredTermInfo); + // Cleanup + Closeables.close(reader, true); + termEntryMap.clear(); + + return clusteredTermInfo.subList(0, Math.min(clusteredTermInfo.size(), maxLabels)); + } + + private static FixedBitSet getClusterDocBitset(IndexReader reader, + Collection<String> idSet, + String idField) throws IOException { + // NOTE: iterating ids 0..numDocs-1 assumes an index without deletions (numDocs == maxDoc) + int numDocs = reader.numDocs(); + + FixedBitSet bitset = new FixedBitSet(numDocs); + + Set<String> idFieldSelector = null; + if (idField != null) { + idFieldSelector = new TreeSet<>(); + idFieldSelector.add(idField); + } + + for (int i = 0; i < numDocs; i++) { + String id; + // Use Lucene's internal ID if idField is not specified. Else, get it from the document. + if (idField == null) { + id = Integer.toString(i); + } else { + id = reader.document(i, idFieldSelector).get(idField); + } + if (idSet.contains(id)) { + bitset.set(i); + } + } + log.info("Created bitset for in-cluster documents: {}", bitset.cardinality()); + return bitset; + } + + private static double scoreDocumentFrequencies(long inDF, long outDF, long clusterSize, long corpusSize) { + // 2x2 contingency table for the term: + // k11 = inDF (in-cluster docs with the term), k12 = in-cluster docs without it, + // k21 = outDF (out-of-cluster docs with the term), k22 = out-of-cluster docs without it + long k12 = clusterSize - inDF; + long k22 = corpusSize - clusterSize - outDF; + + return LogLikelihood.logLikelihoodRatio(inDF, k12, outDF, k22); + } + + public String getIdField() { + return idField; + } + + public void setIdField(String idField) { + this.idField = idField; + } + + public String getOutput() { + return output; + } + + public void setOutput(String output) { + this.output = output; + } + + public static void main(String[] args) { + + DefaultOptionBuilder obuilder = new DefaultOptionBuilder(); + ArgumentBuilder abuilder = new ArgumentBuilder(); + GroupBuilder gbuilder = new GroupBuilder(); + + Option indexOpt = obuilder.withLongName("dir").withRequired(true).withArgument( + abuilder.withName("dir").withMinimum(1).withMaximum(1).create()) + .withDescription("The Lucene index directory").withShortName("d").create(); + + Option outputOpt = obuilder.withLongName("output").withRequired(false).withArgument( + abuilder.withName("output").withMinimum(1).withMaximum(1).create()).withDescription( + "The output file. 
If not specified, the result is printed to the console.").withShortName("o").create(); + + Option fieldOpt = obuilder.withLongName("field").withRequired(true).withArgument( + abuilder.withName("field").withMinimum(1).withMaximum(1).create()) + .withDescription("The content field in the index").withShortName("f").create(); + + Option idFieldOpt = obuilder.withLongName("idField").withRequired(false).withArgument( + abuilder.withName("idField").withMinimum(1).withMaximum(1).create()).withDescription( + "The field for the document ID in the index. If null, then the Lucene internal doc " + + "id is used, which is prone to error if the underlying index changes").withShortName("i").create(); + + Option seqOpt = obuilder.withLongName("seqFileDir").withRequired(true).withArgument( + abuilder.withName("seqFileDir").withMinimum(1).withMaximum(1).create()).withDescription( + "The directory containing Sequence Files for the Clusters").withShortName("s").create(); + + Option pointsOpt = obuilder.withLongName("pointsDir").withRequired(true).withArgument( + abuilder.withName("pointsDir").withMinimum(1).withMaximum(1).create()).withDescription( + "The directory containing points sequence files mapping input vectors to their cluster.") + .withShortName("p").create(); + Option minClusterSizeOpt = obuilder.withLongName("minClusterSize").withRequired(false).withArgument( + abuilder.withName("minClusterSize").withMinimum(1).withMaximum(1).create()).withDescription( + "The minimum number of points a cluster must contain for its labels to be printed").withShortName("m").create(); + Option maxLabelsOpt = obuilder.withLongName("maxLabels").withRequired(false).withArgument( + abuilder.withName("maxLabels").withMinimum(1).withMaximum(1).create()).withDescription( + "The maximum number of labels to print per cluster").withShortName("x").create(); + Option helpOpt = DefaultOptionCreator.helpOption(); + + Group group = gbuilder.withName("Options").withOption(indexOpt).withOption(idFieldOpt).withOption(outputOpt) + .withOption(fieldOpt).withOption(seqOpt).withOption(pointsOpt).withOption(helpOpt) + .withOption(maxLabelsOpt).withOption(minClusterSizeOpt).create(); + + try { + Parser parser = new Parser(); + parser.setGroup(group); + CommandLine cmdLine = parser.parse(args); + + if (cmdLine.hasOption(helpOpt)) { + CommandLineUtil.printHelp(group); + return; + } + + Path seqFileDir = new Path(cmdLine.getValue(seqOpt).toString()); + Path pointsDir = new Path(cmdLine.getValue(pointsOpt).toString()); + String indexDir = cmdLine.getValue(indexOpt).toString(); + String contentField = cmdLine.getValue(fieldOpt).toString(); + + String idField = null; + + if (cmdLine.hasOption(idFieldOpt)) { + idField = cmdLine.getValue(idFieldOpt).toString(); + } + String output = null; + if (cmdLine.hasOption(outputOpt)) { + output = cmdLine.getValue(outputOpt).toString(); + } + int maxLabels = DEFAULT_MAX_LABELS; + if (cmdLine.hasOption(maxLabelsOpt)) { + maxLabels = Integer.parseInt(cmdLine.getValue(maxLabelsOpt).toString()); + } + int minSize = DEFAULT_MIN_IDS; + if (cmdLine.hasOption(minClusterSizeOpt)) { + minSize = Integer.parseInt(cmdLine.getValue(minClusterSizeOpt).toString()); + } + ClusterLabels clusterLabel = new ClusterLabels(seqFileDir, pointsDir, indexDir, contentField, minSize, maxLabels); + + if (idField != null) { + clusterLabel.setIdField(idField); + } + if (output != null) { + clusterLabel.setOutput(output); + } + + clusterLabel.getLabels(); + + } catch (OptionException e) { + log.error("Exception", e); + 
CommandLineUtil.printHelp(group); + } catch (IOException e) { + log.error("Exception", e); + } + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/e0573de3/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/lucene/Driver.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/lucene/Driver.java b/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/lucene/Driver.java new file mode 100644 index 0000000..876816f --- /dev/null +++ b/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/lucene/Driver.java @@ -0,0 +1,349 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * <p/> + * http://www.apache.org/licenses/LICENSE-2.0 + * <p/> + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.utils.vectors.lucene; + +import java.io.File; +import java.io.IOException; +import java.io.Writer; +import java.nio.file.Paths; +import java.util.Iterator; + +import com.google.common.base.Preconditions; +import com.google.common.io.Files; +import org.apache.commons.cli2.CommandLine; +import org.apache.commons.cli2.Group; +import org.apache.commons.cli2.Option; +import org.apache.commons.cli2.OptionException; +import org.apache.commons.cli2.builder.ArgumentBuilder; +import org.apache.commons.cli2.builder.DefaultOptionBuilder; +import org.apache.commons.cli2.builder.GroupBuilder; +import org.apache.commons.cli2.commandline.Parser; +import org.apache.commons.io.Charsets; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.LongWritable; +import org.apache.hadoop.io.SequenceFile; +import org.apache.hadoop.io.Text; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.IndexReader; +import org.apache.lucene.store.Directory; +import org.apache.lucene.store.FSDirectory; +import org.apache.mahout.common.CommandLineUtil; +import org.apache.mahout.math.VectorWritable; +import org.apache.mahout.utils.vectors.TermEntry; +import org.apache.mahout.utils.vectors.TermInfo; +import org.apache.mahout.utils.vectors.io.DelimitedTermInfoWriter; +import org.apache.mahout.utils.vectors.io.SequenceFileVectorWriter; +import org.apache.mahout.utils.vectors.io.VectorWriter; +import org.apache.mahout.vectorizer.TF; +import org.apache.mahout.vectorizer.TFIDF; +import org.apache.mahout.vectorizer.Weight; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public final class Driver { + + private static final Logger log = LoggerFactory.getLogger(Driver.class); + + private String luceneDir; + private String outFile; + private String field; + private String idField; + private String dictOut; + private String 
seqDictOut = ""; + private String weightType = "tfidf"; + private String delimiter = "\t"; + private double norm = LuceneIterable.NO_NORMALIZING; + private long maxDocs = Long.MAX_VALUE; + private int minDf = 1; + private int maxDFPercent = 99; + private double maxPercentErrorDocs = 0.0; + + public void dumpVectors() throws IOException { + + File file = new File(luceneDir); + Preconditions.checkArgument(file.isDirectory(), + "Lucene directory: " + file.getAbsolutePath() + + " does not exist or is not a directory"); + Preconditions.checkArgument(maxDocs >= 0, "maxDocs must be >= 0"); + Preconditions.checkArgument(minDf >= 1, "minDf must be >= 1"); + Preconditions.checkArgument(maxDFPercent <= 99, "maxDFPercent must be <= 99"); + + Directory dir = FSDirectory.open(Paths.get(file.getAbsolutePath())); + IndexReader reader = DirectoryReader.open(dir); + + + Weight weight; + if ("tf".equalsIgnoreCase(weightType)) { + weight = new TF(); + } else if ("tfidf".equalsIgnoreCase(weightType)) { + weight = new TFIDF(); + } else { + throw new IllegalArgumentException("Weight type " + weightType + " is not supported"); + } + + TermInfo termInfo = new CachedTermInfo(reader, field, minDf, maxDFPercent); + + LuceneIterable iterable; + if (norm == LuceneIterable.NO_NORMALIZING) { + iterable = new LuceneIterable(reader, idField, field, termInfo, weight, LuceneIterable.NO_NORMALIZING, + maxPercentErrorDocs); + } else { + iterable = new LuceneIterable(reader, idField, field, termInfo, weight, norm, maxPercentErrorDocs); + } + + log.info("Output File: {}", outFile); + + try (VectorWriter vectorWriter = getSeqFileWriter(outFile)) { + long numDocs = vectorWriter.write(iterable, maxDocs); + log.info("Wrote: {} vectors", numDocs); + } + + File dictOutFile = new File(dictOut); + log.info("Dictionary Output file: {}", dictOutFile); + Writer writer = Files.newWriter(dictOutFile, Charsets.UTF_8); + try (DelimitedTermInfoWriter tiWriter = new DelimitedTermInfoWriter(writer, delimiter, field)) { + tiWriter.write(termInfo); + } + + if (!"".equals(seqDictOut)) { + log.info("SequenceFile Dictionary Output file: {}", seqDictOut); + + Path path = new Path(seqDictOut); + Configuration conf = new Configuration(); + FileSystem fs = FileSystem.get(conf); + try (SequenceFile.Writer seqWriter = SequenceFile.createWriter(fs, conf, path, Text.class, IntWritable.class)) { + Text term = new Text(); + IntWritable termIndex = new IntWritable(); + Iterator<TermEntry> termEntries = termInfo.getAllEntries(); + while (termEntries.hasNext()) { + TermEntry termEntry = termEntries.next(); + term.set(termEntry.getTerm()); + termIndex.set(termEntry.getTermIdx()); + seqWriter.append(term, termIndex); + } + } + } + } + + public static void main(String[] args) throws IOException { + + DefaultOptionBuilder obuilder = new DefaultOptionBuilder(); + ArgumentBuilder abuilder = new ArgumentBuilder(); + GroupBuilder gbuilder = new GroupBuilder(); + + Option inputOpt = obuilder.withLongName("dir").withRequired(true).withArgument( + abuilder.withName("dir").withMinimum(1).withMaximum(1).create()) + .withDescription("The Lucene directory").withShortName("d").create(); + + Option outputOpt = obuilder.withLongName("output").withRequired(true).withArgument( + abuilder.withName("output").withMinimum(1).withMaximum(1).create()).withDescription("The output file") + .withShortName("o").create(); + + Option fieldOpt = obuilder.withLongName("field").withRequired(true).withArgument( + abuilder.withName("field").withMinimum(1).withMaximum(1).create()).withDescription( + 
"The field in the index").withShortName("f").create(); + + Option idFieldOpt = obuilder.withLongName("idField").withRequired(false).withArgument( + abuilder.withName("idField").withMinimum(1).withMaximum(1).create()).withDescription( + "The field in the index containing the index. If null, then the Lucene internal doc " + + "id is used which is prone to error if the underlying index changes").create(); + + Option dictOutOpt = obuilder.withLongName("dictOut").withRequired(true).withArgument( + abuilder.withName("dictOut").withMinimum(1).withMaximum(1).create()).withDescription( + "The output of the dictionary").withShortName("t").create(); + + Option seqDictOutOpt = obuilder.withLongName("seqDictOut").withRequired(false).withArgument( + abuilder.withName("seqDictOut").withMinimum(1).withMaximum(1).create()).withDescription( + "The output of the dictionary as sequence file").withShortName("st").create(); + + Option weightOpt = obuilder.withLongName("weight").withRequired(false).withArgument( + abuilder.withName("weight").withMinimum(1).withMaximum(1).create()).withDescription( + "The kind of weight to use. Currently TF or TFIDF").withShortName("w").create(); + + Option delimiterOpt = obuilder.withLongName("delimiter").withRequired(false).withArgument( + abuilder.withName("delimiter").withMinimum(1).withMaximum(1).create()).withDescription( + "The delimiter for outputting the dictionary").withShortName("l").create(); + + Option powerOpt = obuilder.withLongName("norm").withRequired(false).withArgument( + abuilder.withName("norm").withMinimum(1).withMaximum(1).create()).withDescription( + "The norm to use, expressed as either a double or \"INF\" if you want to use the Infinite norm. " + + "Must be greater or equal to 0. The default is not to normalize").withShortName("n").create(); + + Option maxOpt = obuilder.withLongName("max").withRequired(false).withArgument( + abuilder.withName("max").withMinimum(1).withMaximum(1).create()).withDescription( + "The maximum number of vectors to output. If not specified, then it will loop over all docs") + .withShortName("m").create(); + + Option minDFOpt = obuilder.withLongName("minDF").withRequired(false).withArgument( + abuilder.withName("minDF").withMinimum(1).withMaximum(1).create()).withDescription( + "The minimum document frequency. Default is 1").withShortName("md").create(); + + Option maxDFPercentOpt = obuilder.withLongName("maxDFPercent").withRequired(false).withArgument( + abuilder.withName("maxDFPercent").withMinimum(1).withMaximum(1).create()).withDescription( + "The max percentage of docs for the DF. Can be used to remove really high frequency terms." + + " Expressed as an integer between 0 and 100. Default is 99.").withShortName("x").create(); + + Option maxPercentErrorDocsOpt = obuilder.withLongName("maxPercentErrorDocs").withRequired(false).withArgument( + abuilder.withName("maxPercentErrorDocs").withMinimum(1).withMaximum(1).create()).withDescription( + "The max percentage of docs that can have a null term vector. These are noise document and can occur if the " + + "analyzer used strips out all terms in the target field. This percentage is expressed as a value " + + "between 0 and 1. 
The default is 0.").withShortName("err").create(); + + Option helpOpt = obuilder.withLongName("help").withDescription("Print out help").withShortName("h") + .create(); + + Group group = gbuilder.withName("Options").withOption(inputOpt).withOption(idFieldOpt).withOption( + outputOpt).withOption(delimiterOpt).withOption(helpOpt).withOption(fieldOpt).withOption(maxOpt) + .withOption(dictOutOpt).withOption(seqDictOutOpt).withOption(powerOpt).withOption(maxDFPercentOpt) + .withOption(weightOpt).withOption(minDFOpt).withOption(maxPercentErrorDocsOpt).create(); + + try { + Parser parser = new Parser(); + parser.setGroup(group); + CommandLine cmdLine = parser.parse(args); + + if (cmdLine.hasOption(helpOpt)) { + + CommandLineUtil.printHelp(group); + return; + } + + if (cmdLine.hasOption(inputOpt)) { // Lucene case + Driver luceneDriver = new Driver(); + luceneDriver.setLuceneDir(cmdLine.getValue(inputOpt).toString()); + + if (cmdLine.hasOption(maxOpt)) { + luceneDriver.setMaxDocs(Long.parseLong(cmdLine.getValue(maxOpt).toString())); + } + + if (cmdLine.hasOption(weightOpt)) { + luceneDriver.setWeightType(cmdLine.getValue(weightOpt).toString()); + } + + luceneDriver.setField(cmdLine.getValue(fieldOpt).toString()); + + if (cmdLine.hasOption(minDFOpt)) { + luceneDriver.setMinDf(Integer.parseInt(cmdLine.getValue(minDFOpt).toString())); + } + + if (cmdLine.hasOption(maxDFPercentOpt)) { + luceneDriver.setMaxDFPercent(Integer.parseInt(cmdLine.getValue(maxDFPercentOpt).toString())); + } + + if (cmdLine.hasOption(powerOpt)) { + String power = cmdLine.getValue(powerOpt).toString(); + if ("INF".equals(power)) { + luceneDriver.setNorm(Double.POSITIVE_INFINITY); + } else { + luceneDriver.setNorm(Double.parseDouble(power)); + } + } + + if (cmdLine.hasOption(idFieldOpt)) { + luceneDriver.setIdField(cmdLine.getValue(idFieldOpt).toString()); + } + + if (cmdLine.hasOption(maxPercentErrorDocsOpt)) { + luceneDriver.setMaxPercentErrorDocs(Double.parseDouble(cmdLine.getValue(maxPercentErrorDocsOpt).toString())); + } + + luceneDriver.setOutFile(cmdLine.getValue(outputOpt).toString()); + + luceneDriver.setDelimiter(cmdLine.hasOption(delimiterOpt) ? 
cmdLine.getValue(delimiterOpt).toString() : "\t"); + + luceneDriver.setDictOut(cmdLine.getValue(dictOutOpt).toString()); + + if (cmdLine.hasOption(seqDictOutOpt)) { + luceneDriver.setSeqDictOut(cmdLine.getValue(seqDictOutOpt).toString()); + } + + luceneDriver.dumpVectors(); + } + } catch (OptionException e) { + log.error("Exception", e); + CommandLineUtil.printHelp(group); + } + } + + private static VectorWriter getSeqFileWriter(String outFile) throws IOException { + Path path = new Path(outFile); + Configuration conf = new Configuration(); + FileSystem fs = FileSystem.get(conf); + // TODO: Make this parameter driven + + SequenceFile.Writer seqWriter = SequenceFile.createWriter(fs, conf, path, LongWritable.class, + VectorWritable.class); + + return new SequenceFileVectorWriter(seqWriter); + } + + public void setLuceneDir(String luceneDir) { + this.luceneDir = luceneDir; + } + + public void setMaxDocs(long maxDocs) { + this.maxDocs = maxDocs; + } + + public void setWeightType(String weightType) { + this.weightType = weightType; + } + + public void setField(String field) { + this.field = field; + } + + public void setMinDf(int minDf) { + this.minDf = minDf; + } + + public void setMaxDFPercent(int maxDFPercent) { + this.maxDFPercent = maxDFPercent; + } + + public void setNorm(double norm) { + this.norm = norm; + } + + public void setIdField(String idField) { + this.idField = idField; + } + + public void setOutFile(String outFile) { + this.outFile = outFile; + } + + public void setDelimiter(String delimiter) { + this.delimiter = delimiter; + } + + public void setDictOut(String dictOut) { + this.dictOut = dictOut; + } + + public void setSeqDictOut(String seqDictOut) { + this.seqDictOut = seqDictOut; + } + + public void setMaxPercentErrorDocs(double maxPercentErrorDocs) { + this.maxPercentErrorDocs = maxPercentErrorDocs; + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/e0573de3/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/lucene/LuceneIterable.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/lucene/LuceneIterable.java b/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/lucene/LuceneIterable.java new file mode 100644 index 0000000..1af0ed0 --- /dev/null +++ b/community/mahout-mr/integration/src/main/java/org/apache/mahout/utils/vectors/lucene/LuceneIterable.java @@ -0,0 +1,80 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.utils.vectors.lucene; + +import java.util.Iterator; + +import org.apache.lucene.index.IndexReader; +import org.apache.mahout.math.Vector; +import org.apache.mahout.utils.vectors.TermInfo; +import org.apache.mahout.vectorizer.Weight; + +/** + * {@link Iterable} counterpart to {@link LuceneIterator}. + */ +public final class LuceneIterable implements Iterable<Vector> { + + public static final double NO_NORMALIZING = -1.0; + + private final IndexReader indexReader; + private final String field; + private final String idField; + private final TermInfo terminfo; + private final double normPower; + private final double maxPercentErrorDocs; + private final Weight weight; + + public LuceneIterable(IndexReader reader, String idField, String field, TermInfo terminfo, Weight weight) { + this(reader, idField, field, terminfo, weight, NO_NORMALIZING); + } + + public LuceneIterable(IndexReader indexReader, String idField, String field, TermInfo terminfo, Weight weight, + double normPower) { + this(indexReader, idField, field, terminfo, weight, normPower, 0); + } + + /** + * Produce a LuceneIterable that creates {@link Vector}s from the documents and optionally normalizes them. + * + * @param indexReader {@link org.apache.lucene.index.IndexReader} to read the documents from. + * @param idField field containing the id. May be null. + * @param field field to use for the Vector + * @param terminfo the {@link TermInfo} (dictionary) for the field + * @param weight the {@link Weight} used to compute the vector entries, e.g. TF or TFIDF + * @param normPower the normalization value. Must be nonnegative, or {@link #NO_NORMALIZING} + * @param maxPercentErrorDocs the maximum fraction (between 0 and 1) of documents in the Lucene index + * that may have a null term vector + */ + public LuceneIterable(IndexReader indexReader, + String idField, + String field, + TermInfo terminfo, + Weight weight, + double normPower, + double maxPercentErrorDocs) { + this.indexReader = indexReader; + this.idField = idField; + this.field = field; + this.terminfo = terminfo; + this.normPower = normPower; + this.maxPercentErrorDocs = maxPercentErrorDocs; + this.weight = weight; + } + + @Override + public Iterator<Vector> iterator() { + return new LuceneIterator(indexReader, idField, field, terminfo, weight, normPower, maxPercentErrorDocs); + } +}
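
[Editor's note] For orientation, here is a minimal usage sketch of LuceneIterable, mirroring what Driver.dumpVectors above wires together. The index path and the "id"/"body" field names are placeholders, and the snippet assumes it sits in the org.apache.mahout.utils.vectors.lucene package so that CachedTermInfo and LuceneIterable resolve directly:

import java.nio.file.Paths;
import org.apache.lucene.index.DirectoryReader;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.store.FSDirectory;
import org.apache.mahout.math.Vector;
import org.apache.mahout.utils.vectors.TermInfo;
import org.apache.mahout.vectorizer.TFIDF;

public class LuceneIterableExample {
  public static void main(String[] args) throws Exception {
    // Open the Lucene index ("/path/to/index" is a placeholder).
    IndexReader reader = DirectoryReader.open(FSDirectory.open(Paths.get("/path/to/index")));
    // Build the dictionary for the "body" field: keep terms with df >= 1 and df <= 99% of docs.
    TermInfo termInfo = new CachedTermInfo(reader, "body", 1, 99);
    // Iterate TF-IDF weighted vectors; "id" is the document id field, no normalization.
    for (Vector vector : new LuceneIterable(reader, "id", "body", termInfo, new TFIDF())) {
      System.out.println(vector.asFormatString());
    }
    reader.close();
  }
}

Normalization is opt-in: the longer constructors take a normPower (e.g. 2 for the Euclidean norm), while the default NO_NORMALIZING leaves the raw weights untouched.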

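[Editor's note] As a sanity check on the LLR scoring used by ClusterLabels.scoreDocumentFrequencies above, here is a small self-contained sketch of the same 2x2 contingency table; the counts are invented purely for illustration:

import org.apache.mahout.math.stats.LogLikelihood;

public class LlrExample {
  public static void main(String[] args) {
    // Hypothetical counts: a 500-document cluster in a 10,000-document corpus,
    // with a term appearing in 100 in-cluster and 150 out-of-cluster documents.
    long inDF = 100, outDF = 150, clusterSize = 500, corpusSize = 10000;
    long k11 = inDF;                              // in cluster, has term
    long k12 = clusterSize - inDF;                // in cluster, no term
    long k21 = outDF;                             // out of cluster, has term
    long k22 = corpusSize - clusterSize - outDF;  // out of cluster, no term
    // A large score means the term is far more characteristic of the cluster
    // than of the rest of the corpus.
    System.out.println(LogLikelihood.logLikelihoodRatio(k11, k12, k21, k22));
  }
}

This score is what ClusterLabels sorts its TermInfoClusterInOut entries by before truncating the list to maxLabels.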