http://git-wip-us.apache.org/repos/asf/mahout/blob/99a5358f/integration/src/main/java/org/apache/mahout/text/wikipedia/WikipediaDatasetCreatorMapper.java ---------------------------------------------------------------------- diff --git a/integration/src/main/java/org/apache/mahout/text/wikipedia/WikipediaDatasetCreatorMapper.java b/integration/src/main/java/org/apache/mahout/text/wikipedia/WikipediaDatasetCreatorMapper.java deleted file mode 100644 index 50e5f37..0000000 --- a/integration/src/main/java/org/apache/mahout/text/wikipedia/WikipediaDatasetCreatorMapper.java +++ /dev/null @@ -1,142 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.mahout.text.wikipedia; - -import com.google.common.io.Closeables; -import org.apache.commons.lang3.StringEscapeUtils; -import org.apache.hadoop.conf.Configuration; -import org.apache.hadoop.io.DefaultStringifier; -import org.apache.hadoop.io.LongWritable; -import org.apache.hadoop.io.Text; -import org.apache.hadoop.mapreduce.Mapper; -import org.apache.hadoop.util.GenericsUtil; -import org.apache.lucene.analysis.Analyzer; -import org.apache.lucene.analysis.TokenStream; -import org.apache.lucene.analysis.tokenattributes.CharTermAttribute; -import org.apache.mahout.common.ClassUtils; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.io.IOException; -import java.io.StringReader; -import java.util.ArrayList; -import java.util.HashSet; -import java.util.List; -import java.util.Locale; -import java.util.Set; -import java.util.regex.Pattern; - -/** - * Maps over Wikipedia xml format and output all document having the category listed in the input category - * file - * - */ -public class WikipediaDatasetCreatorMapper extends Mapper<LongWritable, Text, Text, Text> { - - private static final Logger log = LoggerFactory.getLogger(WikipediaDatasetCreatorMapper.class); - - private static final Pattern SPACE_NON_ALPHA_PATTERN = Pattern.compile("[\\s\\W]"); - private static final Pattern OPEN_TEXT_TAG_PATTERN = Pattern.compile("<text xml:space=\"preserve\">"); - private static final Pattern CLOSE_TEXT_TAG_PATTERN = Pattern.compile("</text>"); - - private List<String> inputCategories; - private List<Pattern> inputCategoryPatterns; - private boolean exactMatchOnly; - private Analyzer analyzer; - - @Override - protected void map(LongWritable key, Text value, Context context) throws IOException, InterruptedException { - String document = value.toString(); - document = StringEscapeUtils.unescapeHtml4(CLOSE_TEXT_TAG_PATTERN.matcher( - OPEN_TEXT_TAG_PATTERN.matcher(document).replaceFirst("")).replaceAll("")); - String catMatch = findMatchingCategory(document); - if (!"Unknown".equals(catMatch)) { - StringBuilder contents = new StringBuilder(1000); - TokenStream stream = 
analyzer.tokenStream(catMatch, new StringReader(document)); - CharTermAttribute termAtt = stream.addAttribute(CharTermAttribute.class); - stream.reset(); - while (stream.incrementToken()) { - contents.append(termAtt.buffer(), 0, termAtt.length()).append(' '); - } - context.write( - new Text(SPACE_NON_ALPHA_PATTERN.matcher(catMatch).replaceAll("_")), - new Text(contents.toString())); - stream.end(); - Closeables.close(stream, true); - } - } - - @Override - protected void setup(Context context) throws IOException, InterruptedException { - super.setup(context); - - Configuration conf = context.getConfiguration(); - - if (inputCategories == null) { - Set<String> newCategories = new HashSet<>(); - DefaultStringifier<Set<String>> setStringifier = - new DefaultStringifier<>(conf, GenericsUtil.getClass(newCategories)); - String categoriesStr = conf.get("wikipedia.categories", setStringifier.toString(newCategories)); - Set<String> inputCategoriesSet = setStringifier.fromString(categoriesStr); - inputCategories = new ArrayList<>(inputCategoriesSet); - inputCategoryPatterns = new ArrayList<>(inputCategories.size()); - for (String inputCategory : inputCategories) { - inputCategoryPatterns.add(Pattern.compile(".*\\b" + inputCategory + "\\b.*")); - } - - } - - exactMatchOnly = conf.getBoolean("exact.match.only", false); - - if (analyzer == null) { - String analyzerStr = conf.get("analyzer.class", WikipediaAnalyzer.class.getName()); - analyzer = ClassUtils.instantiateAs(analyzerStr, Analyzer.class); - } - - log.info("Configure: Input Categories size: {} Exact Match: {} Analyzer: {}", - inputCategories.size(), exactMatchOnly, analyzer.getClass().getName()); - } - - private String findMatchingCategory(String document) { - int startIndex = 0; - int categoryIndex; - while ((categoryIndex = document.indexOf("[[Category:", startIndex)) != -1) { - categoryIndex += 11; - int endIndex = document.indexOf("]]", categoryIndex); - if (endIndex >= document.length() || endIndex < 0) { - break; - } - String category = document.substring(categoryIndex, endIndex).toLowerCase(Locale.ENGLISH).trim(); - // categories.add(category.toLowerCase()); - if (exactMatchOnly && inputCategories.contains(category)) { - return category; - } - if (!exactMatchOnly) { - for (int i = 0; i < inputCategories.size(); i++) { - String inputCategory = inputCategories.get(i); - Pattern inputCategoryPattern = inputCategoryPatterns.get(i); - if (inputCategoryPattern.matcher(category).matches()) { // inexact match with word boundary. - return inputCategory; - } - } - } - startIndex = endIndex; - } - return "Unknown"; - } -}
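The deleted WikipediaDatasetCreatorMapper above does two things per page: it scans the wikitext for [[Category:...]] markers against the configured category list, and, when a category matches, it runs the page body through a Lucene Analyzer and emits the tokens as one space-separated value keyed by the category. The snippet below is a minimal, standalone sketch of that analyzer loop, not the project's code; it assumes Lucene's StandardAnalyzer as a stand-in for WikipediaAnalyzer, and the class and helper names are hypothetical.

    import java.io.IOException;
    import java.io.StringReader;

    import org.apache.lucene.analysis.Analyzer;
    import org.apache.lucene.analysis.TokenStream;
    import org.apache.lucene.analysis.standard.StandardAnalyzer;
    import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;

    public final class TokenizeSketch {                  // hypothetical class name

      // Mirrors the mapper's analyzer loop: tokenize a page body and return the
      // terms joined by single spaces, which is what the mapper wrote as its value.
      static String tokenize(Analyzer analyzer, String document) throws IOException {
        StringBuilder contents = new StringBuilder(1000);
        try (TokenStream stream = analyzer.tokenStream("text", new StringReader(document))) {
          CharTermAttribute termAtt = stream.addAttribute(CharTermAttribute.class);
          stream.reset();                                // must be called before incrementToken()
          while (stream.incrementToken()) {
            contents.append(termAtt.toString()).append(' ');
          }
          stream.end();                                  // consume end-of-stream state before close()
        }
        return contents.toString().trim();
      }

      public static void main(String[] args) throws IOException {
        // StandardAnalyzer stands in for the WikipediaAnalyzer used by the deleted mapper.
        try (Analyzer analyzer = new StandardAnalyzer()) {
          System.out.println(tokenize(analyzer, "The [[Category:Example]] page text."));
        }
      }
    }

Inside map(), the deleted mapper performed this same loop with the matched category as the field name and wrote the result keyed by that category, with whitespace and non-word characters in the key replaced by underscores.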
http://git-wip-us.apache.org/repos/asf/mahout/blob/99a5358f/integration/src/main/java/org/apache/mahout/text/wikipedia/WikipediaDatasetCreatorReducer.java ---------------------------------------------------------------------- diff --git a/integration/src/main/java/org/apache/mahout/text/wikipedia/WikipediaDatasetCreatorReducer.java b/integration/src/main/java/org/apache/mahout/text/wikipedia/WikipediaDatasetCreatorReducer.java deleted file mode 100644 index bf921fc..0000000 --- a/integration/src/main/java/org/apache/mahout/text/wikipedia/WikipediaDatasetCreatorReducer.java +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.mahout.text.wikipedia; - -import java.io.IOException; - -import org.apache.hadoop.io.Text; -import org.apache.hadoop.mapreduce.Reducer; - -/** - * Can also be used as a local Combiner - */ -public class WikipediaDatasetCreatorReducer extends Reducer<Text, Text, Text, Text> { - - @Override - protected void reduce(Text key, Iterable<Text> values, Context context) throws IOException, InterruptedException { - // Key is label,word, value is the number of times we've seen this label - // word per local node. Output is the same - for (Text value : values) { - context.write(key, value); - } - } -} http://git-wip-us.apache.org/repos/asf/mahout/blob/99a5358f/integration/src/main/java/org/apache/mahout/text/wikipedia/WikipediaMapper.java ---------------------------------------------------------------------- diff --git a/integration/src/main/java/org/apache/mahout/text/wikipedia/WikipediaMapper.java b/integration/src/main/java/org/apache/mahout/text/wikipedia/WikipediaMapper.java deleted file mode 100644 index abd3a04..0000000 --- a/integration/src/main/java/org/apache/mahout/text/wikipedia/WikipediaMapper.java +++ /dev/null @@ -1,179 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.mahout.text.wikipedia; - -import java.io.IOException; -import java.util.HashSet; -import java.util.Locale; -import java.util.Set; -import java.util.regex.Matcher; -import java.util.regex.Pattern; - -import org.apache.commons.lang3.StringEscapeUtils; -import org.apache.hadoop.conf.Configuration; -import org.apache.hadoop.io.DefaultStringifier; -import org.apache.hadoop.io.LongWritable; -import org.apache.hadoop.io.Text; -import org.apache.hadoop.mapreduce.Mapper; -import org.apache.hadoop.util.GenericsUtil; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** - * Maps over Wikipedia xml format and output all document having the category listed in the input category - * file - * - */ -public class WikipediaMapper extends Mapper<LongWritable, Text, Text, Text> { - - private static final Logger log = LoggerFactory.getLogger(WikipediaMapper.class); - - private static final Pattern SPACE_NON_ALPHA_PATTERN = Pattern.compile("[\\s]"); - - private static final String START_DOC = "<text xml:space=\"preserve\">"; - - private static final String END_DOC = "</text>"; - - private static final Pattern TITLE = Pattern.compile("<title>(.*)<\\/title>"); - - private static final String REDIRECT = "<redirect />"; - - private Set<String> inputCategories; - - private boolean exactMatchOnly; - - private boolean all; - - private boolean removeLabels; - - @Override - protected void map(LongWritable key, Text value, Context context) throws IOException, InterruptedException { - - String content = value.toString(); - if (content.contains(REDIRECT)) { - return; - } - String document; - String title; - try { - document = getDocument(content); - title = getTitle(content); - } catch (RuntimeException e) { - // TODO: reporter.getCounter("Wikipedia", "Parse errors").increment(1); - return; - } - - String catMatch = findMatchingCategory(document); - if (!all) { - if ("Unknown".equals(catMatch)) { - return; - } - } - - document = StringEscapeUtils.unescapeHtml4(document); - if (removeLabels) { - document = removeCategoriesFromText(document); - // Reject documents with malformed tags - if (document == null) { - return; - } - } - - // write out in Bayes input style: key: /Category/document_name - String category = "/" + catMatch.toLowerCase(Locale.ENGLISH) + "/" + - SPACE_NON_ALPHA_PATTERN.matcher(title).replaceAll("_"); - - context.write(new Text(category), new Text(document)); - } - - @Override - protected void setup(Context context) throws IOException, InterruptedException { - super.setup(context); - Configuration conf = context.getConfiguration(); - - Set<String> newCategories = new HashSet<>(); - DefaultStringifier<Set<String>> setStringifier = - new DefaultStringifier<>(conf, GenericsUtil.getClass(newCategories)); - - String categoriesStr = conf.get("wikipedia.categories"); - inputCategories = setStringifier.fromString(categoriesStr); - exactMatchOnly = conf.getBoolean("exact.match.only", false); - all = conf.getBoolean("all.files", false); - removeLabels = conf.getBoolean("remove.labels",false); - log.info("Configure: Input Categories size: {} All: {} Exact Match: {} Remove Labels from Text: {}", - inputCategories.size(), all, exactMatchOnly, removeLabels); - } - - private static String getDocument(String xml) { - int start = xml.indexOf(START_DOC) + START_DOC.length(); - int end = xml.indexOf(END_DOC, start); - return xml.substring(start, end); - } - - private static String getTitle(CharSequence xml) { - Matcher m = TITLE.matcher(xml); - return m.find() ? 
m.group(1) : ""; - } - - private String findMatchingCategory(String document) { - int startIndex = 0; - int categoryIndex; - while ((categoryIndex = document.indexOf("[[Category:", startIndex)) != -1) { - categoryIndex += 11; - int endIndex = document.indexOf("]]", categoryIndex); - if (endIndex >= document.length() || endIndex < 0) { - break; - } - String category = document.substring(categoryIndex, endIndex).toLowerCase(Locale.ENGLISH).trim(); - if (exactMatchOnly && inputCategories.contains(category)) { - return category.toLowerCase(Locale.ENGLISH); - } - if (!exactMatchOnly) { - for (String inputCategory : inputCategories) { - if (category.contains(inputCategory)) { // we have an inexact match - return inputCategory.toLowerCase(Locale.ENGLISH); - } - } - } - startIndex = endIndex; - } - return "Unknown"; - } - - private String removeCategoriesFromText(String document) { - int startIndex = 0; - int categoryIndex; - try { - while ((categoryIndex = document.indexOf("[[Category:", startIndex)) != -1) { - int endIndex = document.indexOf("]]", categoryIndex); - if (endIndex >= document.length() || endIndex < 0) { - break; - } - document = document.replace(document.substring(categoryIndex, endIndex + 2), ""); - if (categoryIndex < document.length()) { - startIndex = categoryIndex; - } else { - break; - } - } - } catch(StringIndexOutOfBoundsException e) { - return null; - } - return document; - } -} http://git-wip-us.apache.org/repos/asf/mahout/blob/99a5358f/integration/src/main/java/org/apache/mahout/text/wikipedia/WikipediaXmlSplitter.java ---------------------------------------------------------------------- diff --git a/integration/src/main/java/org/apache/mahout/text/wikipedia/WikipediaXmlSplitter.java b/integration/src/main/java/org/apache/mahout/text/wikipedia/WikipediaXmlSplitter.java deleted file mode 100644 index fc065fe..0000000 --- a/integration/src/main/java/org/apache/mahout/text/wikipedia/WikipediaXmlSplitter.java +++ /dev/null @@ -1,234 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.mahout.text.wikipedia; - -import java.io.BufferedWriter; -import java.io.File; -import java.io.FileInputStream; -import java.io.IOException; -import java.io.OutputStreamWriter; -import java.net.URI; -import java.text.DecimalFormat; -import java.text.NumberFormat; - -import org.apache.commons.cli2.CommandLine; -import org.apache.commons.cli2.Group; -import org.apache.commons.cli2.Option; -import org.apache.commons.cli2.OptionException; -import org.apache.commons.cli2.builder.ArgumentBuilder; -import org.apache.commons.cli2.builder.DefaultOptionBuilder; -import org.apache.commons.cli2.builder.GroupBuilder; -import org.apache.commons.cli2.commandline.Parser; -import org.apache.hadoop.conf.Configuration; -import org.apache.hadoop.fs.FileSystem; -import org.apache.hadoop.fs.Path; -import org.apache.hadoop.io.compress.BZip2Codec; -import org.apache.hadoop.io.compress.CompressionCodec; -import org.apache.mahout.common.CommandLineUtil; -import org.apache.mahout.common.iterator.FileLineIterator; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** - * <p>The Bayes example package provides some helper classes for training the Naive Bayes classifier - * on the Twenty Newsgroups data. See {@code PrepareTwentyNewsgroups} - * for details on running the trainer and - * formatting the Twenty Newsgroups data properly for the training.</p> - * - * <p>The easiest way to prepare the data is to use the ant task in core/build.xml:</p> - * - * <p>{@code ant extract-20news-18828}</p> - * - * <p>This runs the arg line:</p> - * - * <p>{@code -p $\{working.dir\}/20news-18828/ -o $\{working.dir\}/20news-18828-collapse -a $\{analyzer\} -c UTF-8}</p> - * - * <p>To Run the Wikipedia examples (assumes you've built the Mahout Job jar):</p> - * - * <ol> - * <li>Download the Wikipedia Dataset. 
Use the Ant target: {@code ant enwiki-files}</li> - * <li>Chunk the data using the WikipediaXmlSplitter (from the Hadoop home): - * {@code bin/hadoop jar $MAHOUT_HOME/target/mahout-examples-0.x - * org.apache.mahout.classifier.bayes.WikipediaXmlSplitter - * -d $MAHOUT_HOME/examples/temp/enwiki-latest-pages-articles.xml - * -o $MAHOUT_HOME/examples/work/wikipedia/chunks/ -c 64}</li> - * </ol> - */ -public final class WikipediaXmlSplitter { - - private static final Logger log = LoggerFactory.getLogger(WikipediaXmlSplitter.class); - - private WikipediaXmlSplitter() { } - - public static void main(String[] args) throws IOException { - DefaultOptionBuilder obuilder = new DefaultOptionBuilder(); - ArgumentBuilder abuilder = new ArgumentBuilder(); - GroupBuilder gbuilder = new GroupBuilder(); - - Option dumpFileOpt = obuilder.withLongName("dumpFile").withRequired(true).withArgument( - abuilder.withName("dumpFile").withMinimum(1).withMaximum(1).create()).withDescription( - "The path to the wikipedia dump file (.bz2 or uncompressed)").withShortName("d").create(); - - Option outputDirOpt = obuilder.withLongName("outputDir").withRequired(true).withArgument( - abuilder.withName("outputDir").withMinimum(1).withMaximum(1).create()).withDescription( - "The output directory to place the splits in:\n" - + "local files:\n\t/var/data/wikipedia-xml-chunks or\n\tfile:///var/data/wikipedia-xml-chunks\n" - + "Hadoop DFS:\n\thdfs://wikipedia-xml-chunks\n" - + "AWS S3 (blocks):\n\ts3://bucket-name/wikipedia-xml-chunks\n" - + "AWS S3 (native files):\n\ts3n://bucket-name/wikipedia-xml-chunks\n") - - .withShortName("o").create(); - - Option s3IdOpt = obuilder.withLongName("s3ID").withRequired(false).withArgument( - abuilder.withName("s3Id").withMinimum(1).withMaximum(1).create()).withDescription("Amazon S3 ID key") - .withShortName("i").create(); - Option s3SecretOpt = obuilder.withLongName("s3Secret").withRequired(false).withArgument( - abuilder.withName("s3Secret").withMinimum(1).withMaximum(1).create()).withDescription( - "Amazon S3 secret key").withShortName("s").create(); - - Option chunkSizeOpt = obuilder.withLongName("chunkSize").withRequired(true).withArgument( - abuilder.withName("chunkSize").withMinimum(1).withMaximum(1).create()).withDescription( - "The Size of the chunk, in megabytes").withShortName("c").create(); - Option numChunksOpt = obuilder - .withLongName("numChunks") - .withRequired(false) - .withArgument(abuilder.withName("numChunks").withMinimum(1).withMaximum(1).create()) - .withDescription( - "The maximum number of chunks to create. 
If specified, program will only create a subset of the chunks") - .withShortName("n").create(); - Group group = gbuilder.withName("Options").withOption(dumpFileOpt).withOption(outputDirOpt).withOption( - chunkSizeOpt).withOption(numChunksOpt).withOption(s3IdOpt).withOption(s3SecretOpt).create(); - - Parser parser = new Parser(); - parser.setGroup(group); - CommandLine cmdLine; - try { - cmdLine = parser.parse(args); - } catch (OptionException e) { - log.error("Error while parsing options", e); - CommandLineUtil.printHelp(group); - return; - } - - Configuration conf = new Configuration(); - String dumpFilePath = (String) cmdLine.getValue(dumpFileOpt); - String outputDirPath = (String) cmdLine.getValue(outputDirOpt); - - if (cmdLine.hasOption(s3IdOpt)) { - String id = (String) cmdLine.getValue(s3IdOpt); - conf.set("fs.s3n.awsAccessKeyId", id); - conf.set("fs.s3.awsAccessKeyId", id); - } - if (cmdLine.hasOption(s3SecretOpt)) { - String secret = (String) cmdLine.getValue(s3SecretOpt); - conf.set("fs.s3n.awsSecretAccessKey", secret); - conf.set("fs.s3.awsSecretAccessKey", secret); - } - // do not compute crc file when using local FS - conf.set("fs.file.impl", "org.apache.hadoop.fs.RawLocalFileSystem"); - FileSystem fs = FileSystem.get(URI.create(outputDirPath), conf); - - int chunkSize = 1024 * 1024 * Integer.parseInt((String) cmdLine.getValue(chunkSizeOpt)); - - int numChunks = Integer.MAX_VALUE; - if (cmdLine.hasOption(numChunksOpt)) { - numChunks = Integer.parseInt((String) cmdLine.getValue(numChunksOpt)); - } - - String header = "<mediawiki xmlns=\"http://www.mediawiki.org/xml/export-0.3/\" " - + "xmlns:xsi=\"http://www.w3.org/2001/XMLSchema-instance\" " - + "xsi:schemaLocation=\"http://www.mediawiki.org/xml/export-0.3/ " - + "http://www.mediawiki.org/xml/export-0.3.xsd\" " + "version=\"0.3\" " - + "xml:lang=\"en\">\n" + " <siteinfo>\n" + "<sitename>Wikipedia</sitename>\n" - + " <base>http://en.wikipedia.org/wiki/Main_Page</base>\n" - + " <generator>MediaWiki 1.13alpha</generator>\n" + " <case>first-letter</case>\n" - + " <namespaces>\n" + " <namespace key=\"-2\">Media</namespace>\n" - + " <namespace key=\"-1\">Special</namespace>\n" + " <namespace key=\"0\" />\n" - + " <namespace key=\"1\">Talk</namespace>\n" - + " <namespace key=\"2\">User</namespace>\n" - + " <namespace key=\"3\">User talk</namespace>\n" - + " <namespace key=\"4\">Wikipedia</namespace>\n" - + " <namespace key=\"5\">Wikipedia talk</namespace>\n" - + " <namespace key=\"6\">Image</namespace>\n" - + " <namespace key=\"7\">Image talk</namespace>\n" - + " <namespace key=\"8\">MediaWiki</namespace>\n" - + " <namespace key=\"9\">MediaWiki talk</namespace>\n" - + " <namespace key=\"10\">Template</namespace>\n" - + " <namespace key=\"11\">Template talk</namespace>\n" - + " <namespace key=\"12\">Help</namespace>\n" - + " <namespace key=\"13\">Help talk</namespace>\n" - + " <namespace key=\"14\">Category</namespace>\n" - + " <namespace key=\"15\">Category talk</namespace>\n" - + " <namespace key=\"100\">Portal</namespace>\n" - + " <namespace key=\"101\">Portal talk</namespace>\n" + " </namespaces>\n" - + " </siteinfo>\n"; - - StringBuilder content = new StringBuilder(); - content.append(header); - NumberFormat decimalFormatter = new DecimalFormat("0000"); - File dumpFile = new File(dumpFilePath); - - // If the specified path for the input file is incorrect, return immediately - if (!dumpFile.exists()) { - log.error("Input file path {} doesn't exist", dumpFilePath); - return; - } - - FileLineIterator it; - if 
(dumpFilePath.endsWith(".bz2")) { - // default compression format from http://download.wikimedia.org - CompressionCodec codec = new BZip2Codec(); - it = new FileLineIterator(codec.createInputStream(new FileInputStream(dumpFile))); - } else { - // assume the user has previously de-compressed the dump file - it = new FileLineIterator(dumpFile); - } - int fileNumber = 0; - while (it.hasNext()) { - String thisLine = it.next(); - if (thisLine.trim().startsWith("<page>")) { - boolean end = false; - while (!thisLine.trim().startsWith("</page>")) { - content.append(thisLine).append('\n'); - if (it.hasNext()) { - thisLine = it.next(); - } else { - end = true; - break; - } - } - content.append(thisLine).append('\n'); - - if (content.length() > chunkSize || end) { - content.append("</mediawiki>"); - fileNumber++; - String filename = outputDirPath + "/chunk-" + decimalFormatter.format(fileNumber) + ".xml"; - try (BufferedWriter chunkWriter = - new BufferedWriter(new OutputStreamWriter(fs.create(new Path(filename)), "UTF-8"))) { - chunkWriter.write(content.toString(), 0, content.length()); - } - if (fileNumber >= numChunks) { - break; - } - content = new StringBuilder(); - content.append(header); - } - } - } - } -} http://git-wip-us.apache.org/repos/asf/mahout/blob/99a5358f/integration/src/main/java/org/apache/mahout/text/wikipedia/XmlInputFormat.java ---------------------------------------------------------------------- diff --git a/integration/src/main/java/org/apache/mahout/text/wikipedia/XmlInputFormat.java b/integration/src/main/java/org/apache/mahout/text/wikipedia/XmlInputFormat.java deleted file mode 100644 index afd350f..0000000 --- a/integration/src/main/java/org/apache/mahout/text/wikipedia/XmlInputFormat.java +++ /dev/null @@ -1,164 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.mahout.text.wikipedia; - -import com.google.common.io.Closeables; -import org.apache.commons.io.Charsets; -import org.apache.hadoop.conf.Configuration; -import org.apache.hadoop.fs.FSDataInputStream; -import org.apache.hadoop.fs.FileSystem; -import org.apache.hadoop.fs.Path; -import org.apache.hadoop.io.DataOutputBuffer; -import org.apache.hadoop.io.LongWritable; -import org.apache.hadoop.io.Text; -import org.apache.hadoop.mapreduce.InputSplit; -import org.apache.hadoop.mapreduce.RecordReader; -import org.apache.hadoop.mapreduce.TaskAttemptContext; -import org.apache.hadoop.mapreduce.lib.input.FileSplit; -import org.apache.hadoop.mapreduce.lib.input.TextInputFormat; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -import java.io.IOException; - -/** - * Reads records that are delimited by a specific begin/end tag. 
- */ -public class XmlInputFormat extends TextInputFormat { - - private static final Logger log = LoggerFactory.getLogger(XmlInputFormat.class); - - public static final String START_TAG_KEY = "xmlinput.start"; - public static final String END_TAG_KEY = "xmlinput.end"; - - @Override - public RecordReader<LongWritable, Text> createRecordReader(InputSplit split, TaskAttemptContext context) { - try { - return new XmlRecordReader((FileSplit) split, context.getConfiguration()); - } catch (IOException ioe) { - log.warn("Error while creating XmlRecordReader", ioe); - return null; - } - } - - /** - * XMLRecordReader class to read through a given xml document to output xml blocks as records as specified - * by the start tag and end tag - * - */ - public static class XmlRecordReader extends RecordReader<LongWritable, Text> { - - private final byte[] startTag; - private final byte[] endTag; - private final long start; - private final long end; - private final FSDataInputStream fsin; - private final DataOutputBuffer buffer = new DataOutputBuffer(); - private LongWritable currentKey; - private Text currentValue; - - public XmlRecordReader(FileSplit split, Configuration conf) throws IOException { - startTag = conf.get(START_TAG_KEY).getBytes(Charsets.UTF_8); - endTag = conf.get(END_TAG_KEY).getBytes(Charsets.UTF_8); - - // open the file and seek to the start of the split - start = split.getStart(); - end = start + split.getLength(); - Path file = split.getPath(); - FileSystem fs = file.getFileSystem(conf); - fsin = fs.open(split.getPath()); - fsin.seek(start); - } - - private boolean next(LongWritable key, Text value) throws IOException { - if (fsin.getPos() < end && readUntilMatch(startTag, false)) { - try { - buffer.write(startTag); - if (readUntilMatch(endTag, true)) { - key.set(fsin.getPos()); - value.set(buffer.getData(), 0, buffer.getLength()); - return true; - } - } finally { - buffer.reset(); - } - } - return false; - } - - @Override - public void close() throws IOException { - Closeables.close(fsin, true); - } - - @Override - public float getProgress() throws IOException { - return (fsin.getPos() - start) / (float) (end - start); - } - - private boolean readUntilMatch(byte[] match, boolean withinBlock) throws IOException { - int i = 0; - while (true) { - int b = fsin.read(); - // end of file: - if (b == -1) { - return false; - } - // save to buffer: - if (withinBlock) { - buffer.write(b); - } - - // check if we're matching: - if (b == match[i]) { - i++; - if (i >= match.length) { - return true; - } - } else { - i = 0; - } - // see if we've passed the stop point: - if (!withinBlock && i == 0 && fsin.getPos() >= end) { - return false; - } - } - } - - @Override - public LongWritable getCurrentKey() throws IOException, InterruptedException { - return currentKey; - } - - @Override - public Text getCurrentValue() throws IOException, InterruptedException { - return currentValue; - } - - @Override - public void initialize(InputSplit split, TaskAttemptContext context) throws IOException, InterruptedException { - } - - @Override - public boolean nextKeyValue() throws IOException, InterruptedException { - currentKey = new LongWritable(); - currentValue = new Text(); - return next(currentKey, currentValue); - } - } -} http://git-wip-us.apache.org/repos/asf/mahout/blob/99a5358f/integration/src/main/java/org/apache/mahout/utils/Bump125.java ---------------------------------------------------------------------- diff --git a/integration/src/main/java/org/apache/mahout/utils/Bump125.java 
b/integration/src/main/java/org/apache/mahout/utils/Bump125.java deleted file mode 100644 index 1c55090..0000000 --- a/integration/src/main/java/org/apache/mahout/utils/Bump125.java +++ /dev/null @@ -1,62 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.mahout.utils; - -/** - * Helps with making nice intervals at arbitrary scale. - * - * One use case is where we are producing progress or error messages every time an incoming - * record is received. It is generally bad form to produce a message for <i>every</i> input - * so it would be better to produce a message for each of the first 10 records, then every - * other record up to 20 and then every 5 records up to 50 and then every 10 records up to 100, - * more or less. The pattern can now repeat scaled up by 100. The total number of messages will scale - * with the log of the number of input lines which is much more survivable than direct output - * and because early records all get messages, we get indications early. - */ -public class Bump125 { - private static final int[] BUMPS = {1, 2, 5}; - - static int scale(double value, double base) { - double scale = value / base; - // scan for correct step - int i = 0; - while (i < BUMPS.length - 1 && BUMPS[i + 1] <= scale) { - i++; - } - return BUMPS[i]; - } - - static long base(double value) { - return Math.max(1, (long) Math.pow(10, (int) Math.floor(Math.log10(value)))); - } - - private long counter = 0; - - public long increment() { - long delta; - if (counter >= 10) { - long base = base(counter / 4.0); - int scale = scale(counter / 4.0, base); - delta = base * scale; - } else { - delta = 1; - } - counter += delta; - return counter; - } -} http://git-wip-us.apache.org/repos/asf/mahout/blob/99a5358f/integration/src/main/java/org/apache/mahout/utils/MatrixDumper.java ---------------------------------------------------------------------- diff --git a/integration/src/main/java/org/apache/mahout/utils/MatrixDumper.java b/integration/src/main/java/org/apache/mahout/utils/MatrixDumper.java deleted file mode 100644 index f63de83..0000000 --- a/integration/src/main/java/org/apache/mahout/utils/MatrixDumper.java +++ /dev/null @@ -1,138 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.mahout.utils; - -import java.io.File; -import java.io.FileOutputStream; -import java.io.IOException; -import java.io.OutputStream; -import java.io.PrintStream; -import java.util.List; -import java.util.Map; - -import org.apache.commons.io.Charsets; -import org.apache.hadoop.conf.Configuration; -import org.apache.hadoop.fs.Path; -import org.apache.hadoop.util.ToolRunner; -import org.apache.mahout.common.AbstractJob; -import org.apache.mahout.common.iterator.sequencefile.SequenceFileValueIterator; -import org.apache.mahout.math.Matrix; -import org.apache.mahout.math.MatrixWritable; - -/** - * Export a Matrix in various text formats: - * * CSV file - * - * Input format: Hadoop SequenceFile with Text key and MatrixWritable value, 1 pair - * TODO: - * Needs class for key value- should not hard-code to Text. - * Options for row and column headers- stats software can be picky. - * Assumes only one matrix in a file. - */ -public final class MatrixDumper extends AbstractJob { - - private MatrixDumper() { } - - public static void main(String[] args) throws Exception { - ToolRunner.run(new MatrixDumper(), args); - } - - @Override - public int run(String[] args) throws Exception { - - addInputOption(); - addOption("output", "o", "Output path", null); // AbstractJob output feature requires param - Map<String, List<String>> parsedArgs = parseArguments(args); - if (parsedArgs == null) { - return -1; - } - String outputFile = hasOption("output") ? 
getOption("output") : null; - exportCSV(getInputPath(), outputFile, false); - return 0; - } - - private static void exportCSV(Path inputPath, String outputFile, boolean doLabels) throws IOException { - SequenceFileValueIterator<MatrixWritable> it = - new SequenceFileValueIterator<>(inputPath, true, new Configuration()); - Matrix m = it.next().get(); - it.close(); - PrintStream ps = getPrintStream(outputFile); - String[] columnLabels = getLabels(m.numCols(), m.getColumnLabelBindings(), "col"); - String[] rowLabels = getLabels(m.numRows(), m.getRowLabelBindings(), "row"); - if (doLabels) { - ps.print("rowid,"); - ps.print(columnLabels[0]); - for (int c = 1; c < m.numCols(); c++) { - ps.print(',' + columnLabels[c]); - } - ps.println(); - } - for (int r = 0; r < m.numRows(); r++) { - if (doLabels) { - ps.print(rowLabels[0] + ','); - } - ps.print(Double.toString(m.getQuick(r,0))); - for (int c = 1; c < m.numCols(); c++) { - ps.print(","); - ps.print(Double.toString(m.getQuick(r,c))); - } - ps.println(); - } - if (ps != System.out) { - ps.close(); - } - } - - private static PrintStream getPrintStream(String outputPath) throws IOException { - if (outputPath == null) { - return System.out; - } - File outputFile = new File(outputPath); - if (outputFile.exists()) { - outputFile.delete(); - } - outputFile.createNewFile(); - OutputStream os = new FileOutputStream(outputFile); - return new PrintStream(os, false, Charsets.UTF_8.displayName()); - } - - /** - * return the label set, sorted by matrix order - * if there are no labels, fabricate them using the starter string - * @param length - */ - private static String[] getLabels(int length, Map<String,Integer> labels, String start) { - if (labels != null) { - return sortLabels(labels); - } - String[] sorted = new String[length]; - for (int i = 1; i <= length; i++) { - sorted[i] = start + i; - } - return sorted; - } - - private static String[] sortLabels(Map<String,Integer> labels) { - String[] sorted = new String[labels.size()]; - for (Map.Entry<String,Integer> entry : labels.entrySet()) { - sorted[entry.getValue()] = entry.getKey(); - } - return sorted; - } - -} http://git-wip-us.apache.org/repos/asf/mahout/blob/99a5358f/integration/src/main/java/org/apache/mahout/utils/SequenceFileDumper.java ---------------------------------------------------------------------- diff --git a/integration/src/main/java/org/apache/mahout/utils/SequenceFileDumper.java b/integration/src/main/java/org/apache/mahout/utils/SequenceFileDumper.java deleted file mode 100644 index e01868a..0000000 --- a/integration/src/main/java/org/apache/mahout/utils/SequenceFileDumper.java +++ /dev/null @@ -1,168 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.mahout.utils; - -import java.io.File; -import java.io.OutputStreamWriter; -import java.io.Writer; -import java.util.ArrayList; -import java.util.List; - -import com.google.common.io.Closeables; -import com.google.common.io.Files; -import org.apache.commons.io.Charsets; -import org.apache.hadoop.conf.Configuration; -import org.apache.hadoop.fs.FileSystem; -import org.apache.hadoop.fs.FileUtil; -import org.apache.hadoop.fs.Path; -import org.apache.mahout.common.AbstractJob; -import org.apache.mahout.common.Pair; -import org.apache.mahout.common.iterator.sequencefile.PathFilters; -import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterator; -import org.apache.mahout.math.list.IntArrayList; -import org.apache.mahout.math.map.OpenObjectIntHashMap; - -public final class SequenceFileDumper extends AbstractJob { - - public SequenceFileDumper() { - setConf(new Configuration()); - } - - @Override - public int run(String[] args) throws Exception { - - addInputOption(); - addOutputOption(); - addOption("substring", "b", "The number of chars to print out per value", false); - addOption(buildOption("count", "c", "Report the count only", false, false, null)); - addOption("numItems", "n", "Output at most <n> key value pairs", false); - addOption(buildOption("facets", "fa", "Output the counts per key. Note, if there are a lot of unique keys, " - + "this can take up a fair amount of memory", false, false, null)); - addOption(buildOption("quiet", "q", "Print only file contents.", false, false, null)); - - if (parseArguments(args, false, true) == null) { - return -1; - } - - Path[] pathArr; - Configuration conf = new Configuration(); - Path input = getInputPath(); - FileSystem fs = input.getFileSystem(conf); - if (fs.getFileStatus(input).isDir()) { - pathArr = FileUtil.stat2Paths(fs.listStatus(input, PathFilters.logsCRCFilter())); - } else { - pathArr = new Path[1]; - pathArr[0] = input; - } - - - Writer writer; - boolean shouldClose; - if (hasOption("output")) { - shouldClose = true; - writer = Files.newWriter(new File(getOption("output")), Charsets.UTF_8); - } else { - shouldClose = false; - writer = new OutputStreamWriter(System.out, Charsets.UTF_8); - } - try { - for (Path path : pathArr) { - if (!hasOption("quiet")) { - writer.append("Input Path: ").append(String.valueOf(path)).append('\n'); - } - - int sub = Integer.MAX_VALUE; - if (hasOption("substring")) { - sub = Integer.parseInt(getOption("substring")); - } - boolean countOnly = hasOption("count"); - SequenceFileIterator<?, ?> iterator = new SequenceFileIterator<>(path, true, conf); - if (!hasOption("quiet")) { - writer.append("Key class: ").append(iterator.getKeyClass().toString()); - writer.append(" Value Class: ").append(iterator.getValueClass().toString()).append('\n'); - } - OpenObjectIntHashMap<String> facets = null; - if (hasOption("facets")) { - facets = new OpenObjectIntHashMap<>(); - } - long count = 0; - if (countOnly) { - while (iterator.hasNext()) { - Pair<?, ?> record = iterator.next(); - String key = record.getFirst().toString(); - if (facets != null) { - facets.adjustOrPutValue(key, 1, 1); //either insert or add 1 - } - count++; - } - writer.append("Count: ").append(String.valueOf(count)).append('\n'); - } else { - long numItems = Long.MAX_VALUE; - if (hasOption("numItems")) { - numItems = Long.parseLong(getOption("numItems")); - if (!hasOption("quiet")) { - writer.append("Max Items to dump: ").append(String.valueOf(numItems)).append("\n"); - } - } - while (iterator.hasNext() && count < 
numItems) { - Pair<?, ?> record = iterator.next(); - String key = record.getFirst().toString(); - writer.append("Key: ").append(key); - String str = record.getSecond().toString(); - writer.append(": Value: ").append(str.length() > sub - ? str.substring(0, sub) : str); - writer.write('\n'); - if (facets != null) { - facets.adjustOrPutValue(key, 1, 1); //either insert or add 1 - } - count++; - } - if (!hasOption("quiet")) { - writer.append("Count: ").append(String.valueOf(count)).append('\n'); - } - } - if (facets != null) { - List<String> keyList = new ArrayList<>(facets.size()); - - IntArrayList valueList = new IntArrayList(facets.size()); - facets.pairsSortedByKey(keyList, valueList); - writer.append("-----Facets---\n"); - writer.append("Key\t\tCount\n"); - int i = 0; - for (String key : keyList) { - writer.append(key).append("\t\t").append(String.valueOf(valueList.get(i++))).append('\n'); - } - } - } - writer.flush(); - - } finally { - if (shouldClose) { - Closeables.close(writer, false); - } - } - - - return 0; - } - - public static void main(String[] args) throws Exception { - new SequenceFileDumper().run(args); - } - -} http://git-wip-us.apache.org/repos/asf/mahout/blob/99a5358f/integration/src/main/java/org/apache/mahout/utils/SplitInput.java ---------------------------------------------------------------------- diff --git a/integration/src/main/java/org/apache/mahout/utils/SplitInput.java b/integration/src/main/java/org/apache/mahout/utils/SplitInput.java deleted file mode 100644 index 6178f80..0000000 --- a/integration/src/main/java/org/apache/mahout/utils/SplitInput.java +++ /dev/null @@ -1,673 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.mahout.utils; - -import java.io.BufferedReader; -import java.io.IOException; -import java.io.InputStreamReader; -import java.io.OutputStreamWriter; -import java.io.Writer; -import java.nio.charset.Charset; -import java.util.BitSet; - -import com.google.common.base.Preconditions; -import org.apache.commons.cli2.OptionException; -import org.apache.commons.io.Charsets; -import org.apache.hadoop.conf.Configuration; -import org.apache.hadoop.fs.FileStatus; -import org.apache.hadoop.fs.FileSystem; -import org.apache.hadoop.fs.Path; -import org.apache.hadoop.io.SequenceFile; -import org.apache.hadoop.io.Writable; -import org.apache.hadoop.util.ToolRunner; -import org.apache.mahout.common.AbstractJob; -import org.apache.mahout.common.CommandLineUtil; -import org.apache.mahout.common.HadoopUtil; -import org.apache.mahout.common.Pair; -import org.apache.mahout.common.RandomUtils; -import org.apache.mahout.common.commandline.DefaultOptionCreator; -import org.apache.mahout.common.iterator.sequencefile.PathFilters; -import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterator; -import org.apache.mahout.math.jet.random.sampling.RandomSampler; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** - * A utility for splitting files in the input format used by the Bayes - * classifiers or anything else that has one item per line or SequenceFiles (key/value) - * into training and test sets in order to perform cross-validation. - * <p/> - * <p/> - * This class can be used to split directories of files or individual files into - * training and test sets using a number of different methods. - * <p/> - * When executed via {@link #splitDirectory(Path)} or {@link #splitFile(Path)}, - * the lines read from one or more, input files are written to files of the same - * name into the directories specified by the - * {@link #setTestOutputDirectory(Path)} and - * {@link #setTrainingOutputDirectory(Path)} methods. - * <p/> - * The composition of the test set is determined using one of the following - * approaches: - * <ul> - * <li>A contiguous set of items can be chosen from the input file(s) using the - * {@link #setTestSplitSize(int)} or {@link #setTestSplitPct(int)} methods. - * {@link #setTestSplitSize(int)} allocates a fixed number of items, while - * {@link #setTestSplitPct(int)} allocates a percentage of the original input, - * rounded up to the nearest integer. {@link #setSplitLocation(int)} is used to - * control the position in the input from which the test data is extracted and - * is described further below.</li> - * <li>A random sampling of items can be chosen from the input files(s) using - * the {@link #setTestRandomSelectionSize(int)} or - * {@link #setTestRandomSelectionPct(int)} methods, each choosing a fixed test - * set size or percentage of the input set size as described above. The - * {@link RandomSampler} class from {@code mahout-math} is used to create a sample - * of the appropriate size.</li> - * </ul> - * <p/> - * Any one of the methods above can be used to control the size of the test set. - * If multiple methods are called, a runtime exception will be thrown at - * execution time. - * <p/> - * The {@link #setSplitLocation(int)} method is passed an integer from 0 to 100 - * (inclusive) which is translated into the position of the start of the test - * data within the input file. 
- * <p/> - * Given: - * <ul> - * <li>an input file of 1500 lines</li> - * <li>a desired test data size of 10 percent</li> - * </ul> - * <p/> - * <ul> - * <li>A split location of 0 will cause the first 150 items appearing in the - * input set to be written to the test set.</li> - * <li>A split location of 25 will cause items 375-525 to be written to the test - * set.</li> - * <li>A split location of 100 will cause the last 150 items in the input to be - * written to the test set</li> - * </ul> - * The start of the split will always be adjusted forwards in order to ensure - * that the desired test set size is allocated. Split location has no effect is - * random sampling is employed. - */ -public class SplitInput extends AbstractJob { - - private static final Logger log = LoggerFactory.getLogger(SplitInput.class); - - private int testSplitSize = -1; - private int testSplitPct = -1; - private int splitLocation = 100; - private int testRandomSelectionSize = -1; - private int testRandomSelectionPct = -1; - private int keepPct = 100; - private Charset charset = Charsets.UTF_8; - private boolean useSequence; - private boolean useMapRed; - - private Path inputDirectory; - private Path trainingOutputDirectory; - private Path testOutputDirectory; - private Path mapRedOutputDirectory; - - private SplitCallback callback; - - @Override - public int run(String[] args) throws Exception { - - if (parseArgs(args)) { - splitDirectory(); - } - return 0; - } - - public static void main(String[] args) throws Exception { - ToolRunner.run(new Configuration(), new SplitInput(), args); - } - - /** - * Configure this instance based on the command-line arguments contained within provided array. - * Calls {@link #validate()} to ensure consistency of configuration. - * - * @return true if the arguments were parsed successfully and execution should proceed. - * @throws Exception if there is a problem parsing the command-line arguments or the particular - * combination would violate class invariants. - */ - private boolean parseArgs(String[] args) throws Exception { - - addInputOption(); - addOption("trainingOutput", "tr", "The training data output directory", false); - addOption("testOutput", "te", "The test data output directory", false); - addOption("testSplitSize", "ss", "The number of documents held back as test data for each category", false); - addOption("testSplitPct", "sp", "The % of documents held back as test data for each category", false); - addOption("splitLocation", "sl", "Location for start of test data expressed as a percentage of the input file " - + "size (0=start, 50=middle, 100=end", false); - addOption("randomSelectionSize", "rs", "The number of items to be randomly selected as test data ", false); - addOption("randomSelectionPct", "rp", "Percentage of items to be randomly selected as test data when using " - + "mapreduce mode", false); - addOption("charset", "c", "The name of the character encoding of the input files (not needed if using " - + "SequenceFiles)", false); - addOption(buildOption("sequenceFiles", "seq", "Set if the input files are sequence files. Default is false", - false, false, "false")); - addOption(DefaultOptionCreator.methodOption().create()); - addOption(DefaultOptionCreator.overwriteOption().create()); - //TODO: extend this to sequential mode - addOption("keepPct", "k", "The percentage of total data to keep in map-reduce mode, the rest will be ignored. 
" - + "Default is 100%", false); - addOption("mapRedOutputDir", "mro", "Output directory for map reduce jobs", false); - - if (parseArguments(args) == null) { - return false; - } - - try { - inputDirectory = getInputPath(); - - useMapRed = getOption(DefaultOptionCreator.METHOD_OPTION).equalsIgnoreCase(DefaultOptionCreator.MAPREDUCE_METHOD); - - if (useMapRed) { - if (!hasOption("randomSelectionPct")) { - throw new OptionException(getCLIOption("randomSelectionPct"), - "must set randomSelectionPct when mapRed option is used"); - } - if (!hasOption("mapRedOutputDir")) { - throw new OptionException(getCLIOption("mapRedOutputDir"), - "mapRedOutputDir must be set when mapRed option is used"); - } - mapRedOutputDirectory = new Path(getOption("mapRedOutputDir")); - if (hasOption("keepPct")) { - keepPct = Integer.parseInt(getOption("keepPct")); - } - if (hasOption(DefaultOptionCreator.OVERWRITE_OPTION)) { - HadoopUtil.delete(getConf(), mapRedOutputDirectory); - } - } else { - if (!hasOption("trainingOutput") - || !hasOption("testOutput")) { - throw new OptionException(getCLIOption("trainingOutput"), - "trainingOutput and testOutput must be set if mapRed option is not used"); - } - if (!hasOption("testSplitSize") - && !hasOption("testSplitPct") - && !hasOption("randomSelectionPct") - && !hasOption("randomSelectionSize")) { - throw new OptionException(getCLIOption("testSplitSize"), - "must set one of test split size/percentage or randomSelectionSize/percentage"); - } - - trainingOutputDirectory = new Path(getOption("trainingOutput")); - testOutputDirectory = new Path(getOption("testOutput")); - FileSystem fs = trainingOutputDirectory.getFileSystem(getConf()); - if (hasOption(DefaultOptionCreator.OVERWRITE_OPTION)) { - HadoopUtil.delete(fs.getConf(), trainingOutputDirectory); - HadoopUtil.delete(fs.getConf(), testOutputDirectory); - } - fs.mkdirs(trainingOutputDirectory); - fs.mkdirs(testOutputDirectory); - } - - if (hasOption("charset")) { - charset = Charset.forName(getOption("charset")); - } - - if (hasOption("testSplitSize") && hasOption("testSplitPct")) { - throw new OptionException(getCLIOption("testSplitPct"), "must have either split size or split percentage " - + "option, not BOTH"); - } - - if (hasOption("testSplitSize")) { - setTestSplitSize(Integer.parseInt(getOption("testSplitSize"))); - } - - if (hasOption("testSplitPct")) { - setTestSplitPct(Integer.parseInt(getOption("testSplitPct"))); - } - - if (hasOption("splitLocation")) { - setSplitLocation(Integer.parseInt(getOption("splitLocation"))); - } - - if (hasOption("randomSelectionSize")) { - setTestRandomSelectionSize(Integer.parseInt(getOption("randomSelectionSize"))); - } - - if (hasOption("randomSelectionPct")) { - setTestRandomSelectionPct(Integer.parseInt(getOption("randomSelectionPct"))); - } - - useSequence = hasOption("sequenceFiles"); - - } catch (OptionException e) { - log.error("Command-line option Exception", e); - CommandLineUtil.printHelp(getGroup()); - return false; - } - - validate(); - return true; - } - - /** - * Perform a split on directory specified by {@link #setInputDirectory(Path)} by calling {@link #splitFile(Path)} - * on each file found within that directory. - */ - public void splitDirectory() throws IOException, ClassNotFoundException, InterruptedException { - this.splitDirectory(inputDirectory); - } - - /** - * Perform a split on the specified directory by calling {@link #splitFile(Path)} on each file found within that - * directory. 
- */ - public void splitDirectory(Path inputDir) throws IOException, ClassNotFoundException, InterruptedException { - Configuration conf = getConf(); - splitDirectory(conf, inputDir); - } - - /* - * See also splitDirectory(Path inputDir) - * */ - public void splitDirectory(Configuration conf, Path inputDir) - throws IOException, ClassNotFoundException, InterruptedException { - FileSystem fs = inputDir.getFileSystem(conf); - if (fs.getFileStatus(inputDir) == null) { - throw new IOException(inputDir + " does not exist"); - } - if (!fs.getFileStatus(inputDir).isDir()) { - throw new IOException(inputDir + " is not a directory"); - } - - if (useMapRed) { - SplitInputJob.run(conf, inputDir, mapRedOutputDirectory, - keepPct, testRandomSelectionPct); - } else { - // input dir contains one file per category. - FileStatus[] fileStats = fs.listStatus(inputDir, PathFilters.logsCRCFilter()); - for (FileStatus inputFile : fileStats) { - if (!inputFile.isDir()) { - splitFile(inputFile.getPath()); - } - } - } - } - - /** - * Perform a split on the specified input file. Results will be written to files of the same name in the specified - * training and test output directories. The {@link #validate()} method is called prior to executing the split. - */ - public void splitFile(Path inputFile) throws IOException { - Configuration conf = getConf(); - FileSystem fs = inputFile.getFileSystem(conf); - if (fs.getFileStatus(inputFile) == null) { - throw new IOException(inputFile + " does not exist"); - } - if (fs.getFileStatus(inputFile).isDir()) { - throw new IOException(inputFile + " is a directory"); - } - - validate(); - - Path testOutputFile = new Path(testOutputDirectory, inputFile.getName()); - Path trainingOutputFile = new Path(trainingOutputDirectory, inputFile.getName()); - - int lineCount = countLines(fs, inputFile, charset); - - log.info("{} has {} lines", inputFile.getName(), lineCount); - - int testSplitStart = 0; - int testSplitSize = this.testSplitSize; // don't modify state - BitSet randomSel = null; - - if (testRandomSelectionPct > 0 || testRandomSelectionSize > 0) { - testSplitSize = this.testRandomSelectionSize; - - if (testRandomSelectionPct > 0) { - testSplitSize = Math.round(lineCount * testRandomSelectionPct / 100.0f); - } - log.info("{} test split size is {} based on random selection percentage {}", - inputFile.getName(), testSplitSize, testRandomSelectionPct); - long[] ridx = new long[testSplitSize]; - RandomSampler.sample(testSplitSize, lineCount - 1, testSplitSize, 0, ridx, 0, RandomUtils.getRandom()); - randomSel = new BitSet(lineCount); - for (long idx : ridx) { - randomSel.set((int) idx + 1); - } - } else { - if (testSplitPct > 0) { // calculate split size based on percentage - testSplitSize = Math.round(lineCount * testSplitPct / 100.0f); - log.info("{} test split size is {} based on percentage {}", - inputFile.getName(), testSplitSize, testSplitPct); - } else { - log.info("{} test split size is {}", inputFile.getName(), testSplitSize); - } - - if (splitLocation > 0) { // calculate start of split based on percentage - testSplitStart = Math.round(lineCount * splitLocation / 100.0f); - if (lineCount - testSplitStart < testSplitSize) { - // adjust split start downwards based on split size. 
- testSplitStart = lineCount - testSplitSize; - } - log.info("{} test split start is {} based on split location {}", - inputFile.getName(), testSplitStart, splitLocation); - } - - if (testSplitStart < 0) { - throw new IllegalArgumentException("test split size for " + inputFile + " is too large, it would produce an " - + "empty training set from the initial set of " + lineCount + " examples"); - } else if (lineCount - testSplitSize < testSplitSize) { - log.warn("Test set size for {} may be too large, {} is larger than the number of " - + "lines remaining in the training set: {}", - inputFile, testSplitSize, lineCount - testSplitSize); - } - } - int trainCount = 0; - int testCount = 0; - if (!useSequence) { - try (BufferedReader reader = new BufferedReader(new InputStreamReader(fs.open(inputFile), charset)); - Writer trainingWriter = new OutputStreamWriter(fs.create(trainingOutputFile), charset); - Writer testWriter = new OutputStreamWriter(fs.create(testOutputFile), charset)){ - - String line; - int pos = 0; - while ((line = reader.readLine()) != null) { - pos++; - - Writer writer; - if (testRandomSelectionPct > 0) { // Randomly choose - writer = randomSel.get(pos) ? testWriter : trainingWriter; - } else { // Choose based on location - writer = pos > testSplitStart ? testWriter : trainingWriter; - } - - if (writer == testWriter) { - if (testCount >= testSplitSize) { - writer = trainingWriter; - } else { - testCount++; - } - } - if (writer == trainingWriter) { - trainCount++; - } - writer.write(line); - writer.write('\n'); - } - - } - } else { - try (SequenceFileIterator<Writable, Writable> iterator = - new SequenceFileIterator<>(inputFile, false, fs.getConf()); - SequenceFile.Writer trainingWriter = SequenceFile.createWriter(fs, fs.getConf(), trainingOutputFile, - iterator.getKeyClass(), iterator.getValueClass()); - SequenceFile.Writer testWriter = SequenceFile.createWriter(fs, fs.getConf(), testOutputFile, - iterator.getKeyClass(), iterator.getValueClass())) { - - int pos = 0; - while (iterator.hasNext()) { - pos++; - SequenceFile.Writer writer; - if (testRandomSelectionPct > 0) { // Randomly choose - writer = randomSel.get(pos) ? testWriter : trainingWriter; - } else { // Choose based on location - writer = pos > testSplitStart ? testWriter : trainingWriter; - } - - if (writer == testWriter) { - if (testCount >= testSplitSize) { - writer = trainingWriter; - } else { - testCount++; - } - } - if (writer == trainingWriter) { - trainCount++; - } - Pair<Writable, Writable> pair = iterator.next(); - writer.append(pair.getFirst(), pair.getSecond()); - } - - } - } - log.info("file: {}, input: {} train: {}, test: {} starting at {}", - inputFile.getName(), lineCount, trainCount, testCount, testSplitStart); - - // testing; - if (callback != null) { - callback.splitComplete(inputFile, lineCount, trainCount, testCount, testSplitStart); - } - } - - public int getTestSplitSize() { - return testSplitSize; - } - - public void setTestSplitSize(int testSplitSize) { - this.testSplitSize = testSplitSize; - } - - public int getTestSplitPct() { - return testSplitPct; - } - - /** - * Sets the percentage of the input data to allocate to the test split - * - * @param testSplitPct a value between 0 and 100 inclusive. - */ - public void setTestSplitPct(int testSplitPct) { - this.testSplitPct = testSplitPct; - } - - /** - * Sets the percentage of the input data to keep in a map reduce split input job - * - * @param keepPct a value between 0 and 100 inclusive. 
- */ - public void setKeepPct(int keepPct) { - this.keepPct = keepPct; - } - - /** - * Set to true to use map reduce to split the input - * - * @param useMapRed a boolean to indicate whether map reduce should be used - */ - public void setUseMapRed(boolean useMapRed) { - this.useMapRed = useMapRed; - } - - public void setMapRedOutputDirectory(Path mapRedOutputDirectory) { - this.mapRedOutputDirectory = mapRedOutputDirectory; - } - - public int getSplitLocation() { - return splitLocation; - } - - /** - * Set the location of the start of the test/training data split. Expressed as percentage of lines, for example - * 0 indicates that the test data should be taken from the start of the file, 100 indicates that the test data - * should be taken from the end of the input file, while 25 indicates that the test data should be taken from the - * first quarter of the file. - * <p/> - * This option is only relevant in cases where random selection is not employed - * - * @param splitLocation a value between 0 and 100 inclusive. - */ - public void setSplitLocation(int splitLocation) { - this.splitLocation = splitLocation; - } - - public Charset getCharset() { - return charset; - } - - /** - * Set the charset used to read and write files - */ - public void setCharset(Charset charset) { - this.charset = charset; - } - - public Path getInputDirectory() { - return inputDirectory; - } - - /** - * Set the directory from which input data will be read when the {@link #splitDirectory()} method is invoked - */ - public void setInputDirectory(Path inputDir) { - this.inputDirectory = inputDir; - } - - public Path getTrainingOutputDirectory() { - return trainingOutputDirectory; - } - - /** - * Set the directory to which training data will be written. - */ - public void setTrainingOutputDirectory(Path trainingOutputDir) { - this.trainingOutputDirectory = trainingOutputDir; - } - - public Path getTestOutputDirectory() { - return testOutputDirectory; - } - - /** - * Set the directory to which test data will be written. - */ - public void setTestOutputDirectory(Path testOutputDir) { - this.testOutputDirectory = testOutputDir; - } - - public SplitCallback getCallback() { - return callback; - } - - /** - * Sets the callback used to inform the caller that an input file has been successfully split - */ - public void setCallback(SplitCallback callback) { - this.callback = callback; - } - - public int getTestRandomSelectionSize() { - return testRandomSelectionSize; - } - - /** - * Sets the number of random input samples that will be saved to the test set. - */ - public void setTestRandomSelectionSize(int testRandomSelectionSize) { - this.testRandomSelectionSize = testRandomSelectionSize; - } - - public int getTestRandomSelectionPct() { - - return testRandomSelectionPct; - } - - /** - * Sets the number of random input samples that will be saved to the test set as a percentage of the size of the - * input set. - * - * @param randomSelectionPct a value between 0 and 100 inclusive. - */ - public void setTestRandomSelectionPct(int randomSelectionPct) { - this.testRandomSelectionPct = randomSelectionPct; - } - - /** - * Validates that the current instance is in a consistent state - * - * @throws IllegalArgumentException if settings violate class invariants. - * @throws IOException if output directories do not exist or are not directories. - */ - public void validate() throws IOException { - Preconditions.checkArgument(testSplitSize >= 1 || testSplitSize == -1, - "Invalid testSplitSize: " + testSplitSize + ". 
Must be: testSplitSize >= 1 or testSplitSize = -1"); - Preconditions.checkArgument(splitLocation >= 0 && splitLocation <= 100 || splitLocation == -1, - "Invalid splitLocation percentage: " + splitLocation + ". Must be: 0 <= splitLocation <= 100 or splitLocation = -1"); - Preconditions.checkArgument(testSplitPct >= 0 && testSplitPct <= 100 || testSplitPct == -1, - "Invalid testSplitPct percentage: " + testSplitPct + ". Must be: 0 <= testSplitPct <= 100 or testSplitPct = -1"); - Preconditions.checkArgument(testRandomSelectionPct >= 0 && testRandomSelectionPct <= 100 - || testRandomSelectionPct == -1,"Invalid testRandomSelectionPct percentage: " + testRandomSelectionPct + - ". Must be: 0 <= testRandomSelectionPct <= 100 or testRandomSelectionPct = -1"); - - Preconditions.checkArgument(trainingOutputDirectory != null || useMapRed, - "No training output directory was specified"); - Preconditions.checkArgument(testOutputDirectory != null || useMapRed, "No test output directory was specified"); - - // only one of the following may be set, one must be set. - int count = 0; - if (testSplitSize > 0) { - count++; - } - if (testSplitPct > 0) { - count++; - } - if (testRandomSelectionSize > 0) { - count++; - } - if (testRandomSelectionPct > 0) { - count++; - } - - Preconditions.checkArgument(count == 1, "Exactly one of testSplitSize, testSplitPct, testRandomSelectionSize, " - + "testRandomSelectionPct should be set"); - - if (!useMapRed) { - Configuration conf = getConf(); - FileSystem fs = trainingOutputDirectory.getFileSystem(conf); - FileStatus trainingOutputDirStatus = fs.getFileStatus(trainingOutputDirectory); - Preconditions.checkArgument(trainingOutputDirStatus != null && trainingOutputDirStatus.isDir(), - "%s is not a directory", trainingOutputDirectory); - FileStatus testOutputDirStatus = fs.getFileStatus(testOutputDirectory); - Preconditions.checkArgument(testOutputDirStatus != null && testOutputDirStatus.isDir(), - "%s is not a directory", testOutputDirectory); - } - } - - /** - * Count the lines in the file specified as returned by {@code BufferedReader.readLine()} - * - * @param inputFile the file whose lines will be counted - * @param charset the charset of the file to read - * @return the number of lines in the input file. - * @throws IOException if there is a problem opening or reading the file. - */ - public static int countLines(FileSystem fs, Path inputFile, Charset charset) throws IOException { - int lineCount = 0; - try (BufferedReader reader = new BufferedReader(new InputStreamReader(fs.open(inputFile), charset))){ - while (reader.readLine() != null) { - lineCount++; - } - } - return lineCount; - } - - /** - * Used to pass information back to a caller once a file has been split without the need for a data object - */ - public interface SplitCallback { - void splitComplete(Path inputFile, int lineCount, int trainCount, int testCount, int testSplitStart); - } - -} http://git-wip-us.apache.org/repos/asf/mahout/blob/99a5358f/integration/src/main/java/org/apache/mahout/utils/SplitInputJob.java ---------------------------------------------------------------------- diff --git a/integration/src/main/java/org/apache/mahout/utils/SplitInputJob.java b/integration/src/main/java/org/apache/mahout/utils/SplitInputJob.java deleted file mode 100644 index 4a1ff86..0000000 --- a/integration/src/main/java/org/apache/mahout/utils/SplitInputJob.java +++ /dev/null @@ -1,213 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. 
See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.mahout.utils; - -import java.io.IOException; -import java.io.Serializable; -import java.util.Random; - -import org.apache.hadoop.conf.Configuration; -import org.apache.hadoop.fs.FileSystem; -import org.apache.hadoop.fs.Path; -import org.apache.hadoop.io.Writable; -import org.apache.hadoop.io.WritableComparable; -import org.apache.hadoop.io.WritableComparator; -import org.apache.hadoop.mapreduce.Job; -import org.apache.hadoop.mapreduce.Mapper; -import org.apache.hadoop.mapreduce.Reducer; -import org.apache.hadoop.mapreduce.lib.input.FileInputFormat; -import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat; -import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat; -import org.apache.hadoop.mapreduce.lib.output.MultipleOutputs; -import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat; -import org.apache.mahout.common.Pair; -import org.apache.mahout.common.RandomUtils; -import org.apache.mahout.common.iterator.sequencefile.PathFilters; -import org.apache.mahout.common.iterator.sequencefile.PathType; -import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterator; - -/** - * Class which implements a map reduce version of SplitInput. - * This class takes a SequenceFile input, e.g. a set of training data - * for a learning algorithm, downsamples it, applies a random - * permutation and splits it into test and training sets - */ -public final class SplitInputJob { - - private static final String DOWNSAMPLING_FACTOR = "SplitInputJob.downsamplingFactor"; - private static final String RANDOM_SELECTION_PCT = "SplitInputJob.randomSelectionPct"; - private static final String TRAINING_TAG = "training"; - private static final String TEST_TAG = "test"; - - private SplitInputJob() {} - - /** - * Run job to downsample, randomly permute and split data into test and - * training sets. This job takes a SequenceFile as input and outputs two - * SequenceFiles test-r-00000 and training-r-00000 which contain the test and - * training sets respectively - * - * @param initialConf - * Initial configuration - * @param inputPath - * path to input data SequenceFile - * @param outputPath - * path for output data SequenceFiles - * @param keepPct - * percentage of key value pairs in input to keep. The rest are - * discarded - * @param randomSelectionPercent - * percentage of key value pairs to allocate to test set. 
Remainder - * are allocated to training set - */ - @SuppressWarnings("rawtypes") - public static void run(Configuration initialConf, Path inputPath, - Path outputPath, int keepPct, float randomSelectionPercent) - throws IOException, ClassNotFoundException, InterruptedException { - - int downsamplingFactor = (int) (100.0 / keepPct); - initialConf.setInt(DOWNSAMPLING_FACTOR, downsamplingFactor); - initialConf.setFloat(RANDOM_SELECTION_PCT, randomSelectionPercent); - - // Determine class of keys and values - FileSystem fs = FileSystem.get(initialConf); - - SequenceFileDirIterator<? extends WritableComparable, Writable> iterator = - new SequenceFileDirIterator<>(inputPath, - PathType.LIST, PathFilters.partFilter(), null, false, fs.getConf()); - Class<? extends WritableComparable> keyClass; - Class<? extends Writable> valueClass; - if (iterator.hasNext()) { - Pair<? extends WritableComparable, Writable> pair = iterator.next(); - keyClass = pair.getFirst().getClass(); - valueClass = pair.getSecond().getClass(); - } else { - throw new IllegalStateException("Couldn't determine class of the input values"); - } - - Job job = new Job(new Configuration(initialConf)); - - MultipleOutputs.addNamedOutput(job, TRAINING_TAG, SequenceFileOutputFormat.class, keyClass, valueClass); - MultipleOutputs.addNamedOutput(job, TEST_TAG, SequenceFileOutputFormat.class, keyClass, valueClass); - job.setJarByClass(SplitInputJob.class); - FileInputFormat.addInputPath(job, inputPath); - FileOutputFormat.setOutputPath(job, outputPath); - job.setNumReduceTasks(1); - job.setInputFormatClass(SequenceFileInputFormat.class); - job.setOutputFormatClass(SequenceFileOutputFormat.class); - job.setMapperClass(SplitInputMapper.class); - job.setReducerClass(SplitInputReducer.class); - job.setSortComparatorClass(SplitInputComparator.class); - job.setOutputKeyClass(keyClass); - job.setOutputValueClass(valueClass); - job.submit(); - boolean succeeded = job.waitForCompletion(true); - if (!succeeded) { - throw new IllegalStateException("Job failed!"); - } - } - - /** Mapper which downsamples the input by downsamplingFactor */ - public static class SplitInputMapper extends - Mapper<WritableComparable<?>, Writable, WritableComparable<?>, Writable> { - - private int downsamplingFactor; - - @Override - public void setup(Context ctx) { - downsamplingFactor = ctx.getConfiguration().getInt(DOWNSAMPLING_FACTOR, 1); - } - - /** Only run map() for one out of every downsampleFactor inputs */ - @Override - public void run(Context context) throws IOException, InterruptedException { - setup(context); - int i = 0; - while (context.nextKeyValue()) { - if (i % downsamplingFactor == 0) { - map(context.getCurrentKey(), context.getCurrentValue(), context); - } - i++; - } - cleanup(context); - } - - } - - /** Reducer which uses MultipleOutputs to randomly allocate key value pairs between test and training outputs */ - public static class SplitInputReducer extends - Reducer<WritableComparable<?>, Writable, WritableComparable<?>, Writable> { - - private MultipleOutputs multipleOutputs; - private final Random rnd = RandomUtils.getRandom(); - private float randomSelectionPercent; - - @Override - protected void setup(Context ctx) throws IOException { - randomSelectionPercent = ctx.getConfiguration().getFloat(RANDOM_SELECTION_PCT, 0); - multipleOutputs = new MultipleOutputs(ctx); - } - - /** - * Randomly allocate key value pairs between test and training sets. - * randomSelectionPercent of the pairs will go to the test set. 
- */ - @Override - protected void reduce(WritableComparable<?> key, Iterable<Writable> values, - Context context) throws IOException, InterruptedException { - for (Writable value : values) { - if (rnd.nextInt(100) < randomSelectionPercent) { - multipleOutputs.write(TEST_TAG, key, value); - } else { - multipleOutputs.write(TRAINING_TAG, key, value); - } - } - - } - - @Override - protected void cleanup(Context context) throws IOException { - try { - multipleOutputs.close(); - } catch (InterruptedException e) { - throw new IOException(e); - } - } - - } - - /** Randomly permute key value pairs */ - public static class SplitInputComparator extends WritableComparator implements Serializable { - - private final Random rnd = RandomUtils.getRandom(); - - protected SplitInputComparator() { - super(WritableComparable.class); - } - - @Override - public int compare(byte[] b1, int s1, int l1, byte[] b2, int s2, int l2) { - if (rnd.nextBoolean()) { - return 1; - } else { - return -1; - } - } - } - -}
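For reference, the removed SplitInput utility could also be driven programmatically rather than through its command-line options. The sketch below is illustrative only and rests on assumptions about code not shown in this diff: that the class lives alongside SplitInputJob as org.apache.mahout.utils.SplitInput, that it has a public no-arg constructor, and that setConf(...) is inherited from Mahout's AbstractJob/Configured. Paths, the charset, and the split percentage are placeholders.

import java.nio.charset.Charset;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.mahout.utils.SplitInput;

public class SplitInputSketch {
  public static void main(String[] args) throws Exception {
    Configuration conf = new Configuration();
    Path input = new Path("categories");       // placeholder: one input file per category
    Path train = new Path("splits/train");     // placeholder output directories
    Path test = new Path("splits/test");

    // validate() expects both output directories to already exist.
    FileSystem fs = train.getFileSystem(conf);
    fs.mkdirs(train);
    fs.mkdirs(test);

    SplitInput splitter = new SplitInput();
    splitter.setConf(conf);                    // assumed to come from AbstractJob/Configured
    splitter.setInputDirectory(input);
    splitter.setTrainingOutputDirectory(train);
    splitter.setTestOutputDirectory(test);
    splitter.setCharset(Charset.forName("UTF-8"));
    splitter.setTestSplitPct(25);              // exactly one sizing option may be set (see validate())
    splitter.splitDirectory();                 // sequential mode: calls splitFile(...) per input file
  }
}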

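The static entry point of the removed SplitInputJob appears in full above, so a direct invocation can be sketched from its signature alone; paths and percentages below are placeholders. With keepPct = 50 the mapper keeps one record in every 100 / 50 = 2, and the reducer then routes roughly randomSelectionPercent of the kept records to the test output (test-r-00000) and the remainder to the training output (training-r-00000), as described in the class javadoc.

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.mahout.utils.SplitInputJob;

public class SplitInputJobSketch {
  public static void main(String[] args) throws Exception {
    Configuration conf = new Configuration();
    Path input = new Path("vectors");          // placeholder: directory of SequenceFile part files
    Path output = new Path("vectors-split");   // placeholder: receives test-r-00000 and training-r-00000

    // keepPct = 50 -> downsampling factor 2; randomSelectionPercent = 20.0f -> ~20% of kept pairs go to the test set.
    SplitInputJob.run(conf, input, output, 50, 20.0f);
  }
}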