Author: drew
Date: Sat Aug 14 00:30:47 2010
New Revision: 985413
URL: http://svn.apache.org/viewvc?rev=985413&view=rev
Log:
MAHOUT-451: Simple utility to split bayes input into training/test sets
Added:
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/bayes/SplitBayesInput.java
mahout/trunk/examples/src/test/java/org/apache/mahout/classifier/
mahout/trunk/examples/src/test/java/org/apache/mahout/classifier/bayes/
mahout/trunk/examples/src/test/java/org/apache/mahout/classifier/bayes/SplitBayesInputTest.java
Added:
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/bayes/SplitBayesInput.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/bayes/SplitBayesInput.java?rev=985413&view=auto
==============================================================================
---
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/bayes/SplitBayesInput.java
(added)
+++
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/bayes/SplitBayesInput.java
Sat Aug 14 00:30:47 2010
@@ -0,0 +1,597 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.bayes;
+
+import java.io.BufferedReader;
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.io.InputStreamReader;
+import java.io.OutputStreamWriter;
+import java.io.Writer;
+import java.nio.charset.Charset;
+import java.util.BitSet;
+import java.util.Date;
+
+import org.apache.commons.cli2.CommandLine;
+import org.apache.commons.cli2.Group;
+import org.apache.commons.cli2.Option;
+import org.apache.commons.cli2.OptionException;
+import org.apache.commons.cli2.builder.ArgumentBuilder;
+import org.apache.commons.cli2.builder.DefaultOptionBuilder;
+import org.apache.commons.cli2.builder.GroupBuilder;
+import org.apache.commons.cli2.commandline.Parser;
+import org.apache.mahout.common.CommandLineUtil;
+import org.apache.mahout.common.IOUtils;
+import org.apache.mahout.common.commandline.DefaultOptionCreator;
+import org.apache.mahout.math.jet.random.engine.MersenneTwister;
+import org.apache.mahout.math.jet.random.sampling.RandomSampler;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * A utility for splitting files in the input format used by the Bayes
+ * classifiers into training and test sets in order to perform
cross-validation.
+ * This class is not strictly confined to working with the Bayes classifier
+ * input. It can be used for any input files where each line is a complete
+ * sample.
+ * <p>
+ * This class can be used to split directories of files or individual files
into
+ * training and test sets using a number of different methods.
+ * <p>
+ * When executed via {...@link #splitDirectory(File)} or {...@link
#splitFile(File)},
+ * the lines read from one or more, input files are written to files of the
same
+ * name into the directories specified by the
+ * {...@link #setTestOutputDirectory(File)} and
+ * {...@link #setTrainingOutputDirectory(File)} methods.
+ * <p>
+ * The composition of the test set is determined using one of the following
+ * approaches:
+ * <ul>
+ * <li>A contiguous set of items can be chosen from the input file(s) using the
+ * {...@link #setTestSplitSize(int)} or {...@link #setTestSplitPct(int)}
methods.
+ * {...@link #setTestSplitSize(int)} allocates a fixed number of items, while
+ * {...@link #setTestSplitPct(int)} allocates a percentage of the original
input,
+ * rounded up to the nearest integer. {...@link #setSplitLocation(int)} is
used to
+ * control the position in the input from which the test data is extracted and
+ * is described further below.</li>
+ * <li>A random sampling of items can be chosen from the input files(s) using
+ * the {...@link #setTestRandomSelectSize(int)} or
+ * {...@link #setTestRandomSelectionPct(int)} methods, each choosing a fixed
test
+ * set size or percentage of the input set size as described above. The
+ * {...@link org.apache.mahout.math.jet.random.sampling.RandomSampler
+ * RandomSampler} class from <code>mahout-math</code> is used to create a
sample
+ * of the appropriate size.</li>
+ * </ul>
+ * <p>
+ * Any one of the methods above can be used to control the size of the test
set.
+ * If multiple methods are called, a runtime exception will be thrown at
+ * execution time.
+ * <p>
+ * The {...@link #setSplitLocation(int)} method is passed an integer from 0 to
100
+ * (inclusive) which is translated into the position of the start of the test
+ * data within the input file.
+ * <p>
+ * Given:
+ * <ul>
+ * <li>an input file of 1500 lines</li>
+ * <li>a desired test data size of 10 percent</li>
+ * </ul>
+ * <p>
+ * <ul>
+ * <li>A split location of 0 will cause the first 150 items appearing in the
+ * input set to be written to the test set.</li>
+ * <li>A split location of 25 will cause items 375-525 to be written to the
test
+ * set.</li>
+ * <li>A split location of 100 will cause the last 150 items in the input to be
+ * written to the test set</li>
+ * </ul>
+ * The start of the split will always be adjusted forwards in order to ensure
+ * that the desired test set size is allocated. Split location has no effect is
+ * random sampling is employed.
+ */
+public class SplitBayesInput {
+
+ private static final Logger log =
LoggerFactory.getLogger(SplitBayesInput.class);
+
+ private int testSplitSize = -1;
+ private int testSplitPct = -1;
+ private int splitLocation = 100;
+ private int testRandomSelectionSize = -1;
+ private int testRandomSelectionPct = -1;
+ private Charset charset = Charset.forName("UTF-8");
+
+ private File inputDirectory;
+ private File trainingOutputDirectory;
+ private File testOutputDirectory;
+
+ private SplitCallback callback;
+
+ public static void main(String[] args) throws Exception {
+ SplitBayesInput si = new SplitBayesInput();
+ if (si.parseArgs(args)) {
+ si.splitDirectory();
+ }
+ }
+
+ /** Configure this instance based on the command-line arguments contained
within provided array.
+ * Calls {...@link #validate()} to ensure consistency of configuration.
+ *
+ * @param args.
+ * @return true if the arguments were parsed successfully and execution
should proceed.
+ * @throws Exception if there is a problem parsing the command-line
arguments or the particular
+ * combination would violate class invariants.
+ */
+ public boolean parseArgs(String[] args) throws Exception {
+
+ DefaultOptionBuilder obuilder = new DefaultOptionBuilder();
+ ArgumentBuilder abuilder = new ArgumentBuilder();
+ GroupBuilder gbuilder = new GroupBuilder();
+ Option helpOpt = DefaultOptionCreator.helpOption();
+
+ Option inputDirOpt =
obuilder.withLongName("inputDir").withRequired(true).withArgument(
+
abuilder.withName("inputDir").withMinimum(1).withMaximum(1).create()).withDescription(
+ "The input directory").withShortName("i").create();
+
+ Option trainingOutputDirOpt =
obuilder.withLongName("trainingOutputDir").withRequired(true).withArgument(
+
abuilder.withName("outputDir").withMinimum(1).withMaximum(1).create()).withDescription(
+ "The training data output directory").withShortName("tr").create();
+
+ Option testOutputDirOpt =
obuilder.withLongName("testOutputDir").withRequired(true).withArgument(
+
abuilder.withName("outputDir").withMinimum(1).withMaximum(1).create()).withDescription(
+ "The test data output directory").withShortName("te").create();
+
+ Option testSplitSizeOpt =
obuilder.withLongName("testSplitSize").withRequired(false).withArgument(
+
abuilder.withName("splitSize").withMinimum(1).withMaximum(1).create()).withDescription(
+ "The number of documents held back as test data for each
category").withShortName("ss").create();
+
+ Option testSplitPctOpt =
obuilder.withLongName("testSplitPct").withRequired(false).withArgument(
+
abuilder.withName("splitPct").withMinimum(1).withMaximum(1).create()).withDescription(
+ "The percentage of documents held back as test data for each
category").withShortName("sp").create();
+
+ Option splitLocationOpt =
obuilder.withLongName("splitLocation").withRequired(false).withArgument(
+
abuilder.withName("splitLoc").withMinimum(1).withMaximum(1).create()).withDescription(
+ "Location for start of test data expressed as a percentage of the
input file size (0=start, 50=middle, 100=end")
+ .withShortName("sl").create();
+
+ Option randomSelectionSizeOpt =
obuilder.withLongName("randomSelectionSize").withRequired(false).withArgument(
+
abuilder.withName("randomSize").withMinimum(1).withMaximum(1).create()).withDescription(
+ "The number of itemr to be randomly selected as test data
").withShortName("rs").create();
+
+ Option randomSelectionPctOpt =
obuilder.withLongName("randomSelectionPct").withRequired(false).withArgument(
+
abuilder.withName("randomPct").withMinimum(1).withMaximum(1).create()).withDescription(
+ "Percentage of items to be randomly selected as test data
").withShortName("rp").create();
+
+ Option charsetOpt =
obuilder.withLongName("charset").withRequired(true).withArgument(
+
abuilder.withName("charset").withMinimum(1).withMaximum(1).create()).withDescription(
+ "The name of the character encoding of the input
files").withShortName("c").create();
+
+ Group group =
gbuilder.withName("Options").withOption(inputDirOpt).withOption(trainingOutputDirOpt)
+
.withOption(testOutputDirOpt).withOption(testSplitSizeOpt).withOption(testSplitPctOpt)
+
.withOption(splitLocationOpt).withOption(randomSelectionSizeOpt).withOption(randomSelectionPctOpt)
+ .withOption(charsetOpt).create();
+
+ try {
+
+ Parser parser = new Parser();
+ parser.setGroup(group);
+ CommandLine cmdLine = parser.parse(args);
+
+ if (cmdLine.hasOption(helpOpt)) {
+ CommandLineUtil.printHelp(group);
+ return false;
+ }
+
+ inputDirectory = new File((String) cmdLine.getValue(inputDirOpt));
+ trainingOutputDirectory = new File((String)
cmdLine.getValue(trainingOutputDirOpt));
+ testOutputDirectory = new File((String)
cmdLine.getValue(testOutputDirOpt));
+
+ charset = Charset.forName((String) cmdLine.getValue(charsetOpt));
+
+ if (cmdLine.hasOption(testSplitSizeOpt) &&
cmdLine.hasOption(testSplitPctOpt)) {
+ throw new OptionException(testSplitSizeOpt, "must have either split
size or split percentage option, not BOTH");
+ }
+ else if (!cmdLine.hasOption(testSplitSizeOpt) &&
!cmdLine.hasOption(testSplitPctOpt)) {
+ throw new OptionException(testSplitSizeOpt, "must have either split
size or split percentage option");
+ }
+
+ if (cmdLine.hasOption(testSplitSizeOpt)) {
+ setTestSplitSize(Integer.parseInt((String)
cmdLine.getValue(testSplitSizeOpt)));
+ }
+
+ if (cmdLine.hasOption(testSplitPctOpt)) {
+ setTestSplitPct(Integer.parseInt((String)
cmdLine.getValue(testSplitPctOpt)));
+ }
+
+ if (cmdLine.hasOption(splitLocationOpt)) {
+ setSplitLocation(Integer.parseInt((String)
cmdLine.getValue(splitLocationOpt)));
+ }
+
+ if (cmdLine.hasOption(randomSelectionSizeOpt)) {
+ setTestRandomSelectionSize(Integer.parseInt((String)
cmdLine.getValue(randomSelectionSizeOpt)));
+ }
+
+ if (cmdLine.hasOption(randomSelectionPctOpt)) {
+ setTestRandomSelectionPct(Integer.parseInt((String)
cmdLine.getValue(randomSelectionPctOpt)));
+ }
+
+ trainingOutputDirectory.mkdirs();
+ testOutputDirectory.mkdirs();
+
+ } catch (OptionException e) {
+ log.error("Command-line option Exception", e);
+ CommandLineUtil.printHelp(group);
+ return false;
+ }
+
+ validate();
+ return true;
+ }
+
+ /** Perform a split on directory specified by {...@link
#setInputDirectory(File)} by calling {...@link #splitFile(File)}
+ * on each file found within that directory.
+ *
+ * @param inputDir
+ * @throws IOException
+ */
+ public void splitDirectory() throws IOException {
+ this.splitDirectory(inputDirectory);
+ }
+
+ /** Perform a split on the specified directory by calling {...@link
#splitFile(File)} on each file found within that
+ * directory.
+ *
+ * @param inputDir
+ * @throws IOException
+ */
+ public void splitDirectory(File inputDir) throws IOException {
+ if (!inputDir.isDirectory())
+ throw new IOException(inputDir + " does not exist, or is not a
directory");
+
+ // input dir contains one file per category.
+ File[] inputFiles = inputDir.listFiles();
+ for (File inputFile : inputFiles) {
+ if (inputFile.isFile())
+ splitFile(inputFile);
+ }
+ }
+
+ /** Perform a split on the specified input file. Results will be written to
files of the same name in the specified
+ * training and test output directories. The {...@link #validate()} method
is called prior to executing the split.
+ *
+ * @param inputFile
+ * @throws IOException
+ */
+ public void splitFile(File inputFile) throws IOException {
+ if (!inputFile.isFile())
+ throw new IOException(inputFile + " does not exist, or is not a file");
+
+ validate();
+
+ File testOutputFile = new File(testOutputDirectory, inputFile.getName());
+ File trainingOutputFile = new File(trainingOutputDirectory,
inputFile.getName());
+
+ int lineCount = countLines(inputFile, charset);
+
+ log.info(inputFile.getName() + " has " + lineCount + " lines");
+
+ int testSplitStart = 0;
+ int testSplitSize = this.testSplitSize; // don't modify state
+ BitSet randomSel = null;
+
+ if (testRandomSelectionPct > 0 || testRandomSelectionSize > 0) {
+ testSplitSize = this.testRandomSelectionSize;
+
+ if (testRandomSelectionPct > 0) {
+ testSplitSize = Math.round(lineCount * ( testRandomSelectionPct /
100.0f ));
+ }
+ log.info(inputFile.getName() + " test split size is " + testSplitSize +
+ " based on random selection percentage " + testRandomSelectionPct);
+ long[] ridx = new long[testSplitSize];
+ RandomSampler.sample(testSplitSize, lineCount-1, testSplitSize, 0, ridx,
0, new MersenneTwister(new Date()));
+ randomSel = new BitSet(lineCount);
+ for (int i=0; i < ridx.length; i++) {
+ randomSel.set((int) ridx[i] + 1);
+ }
+ }
+ else {
+ if (testSplitPct > 0) { // calculate split size based on percentage
+ testSplitSize = Math.round(lineCount * ( testSplitPct / 100.0f ));
+ log.info(inputFile.getName() + " test split size is " + testSplitSize
+
+ " based on percentage " + testSplitPct);
+ }
+ else {
+ log.info(inputFile.getName() + " test split size is " + testSplitSize);
+ }
+
+ if (splitLocation > 0) { // calculate start of split based on percentage
+ testSplitStart = Math.round(lineCount * ( splitLocation / 100.0f ));
+ if (lineCount - testSplitStart < testSplitSize) {
+ // adjust split start downwards based on split size.
+ testSplitStart = lineCount - testSplitSize;
+ }
+ log.info(inputFile.getName() + " test split start is " +
testSplitStart +
+ " based on split location " + splitLocation);
+ }
+
+ if (testSplitStart < 0) {
+ throw new IllegalArgumentException("test split size for " + inputFile
+ " is too large, it would produce an "
+ + "empty training set from the initial set of " + lineCount + "
examples");
+ }
+ else if ((lineCount - testSplitSize) < testSplitSize) {
+ log.warn("CAUTION: Test set size for " + inputFile + " may be too
large, " + testSplitSize +
+ " is larger than the number of lines remaining in the training
set: " + (lineCount - testSplitSize));
+ }
+ }
+ BufferedReader reader = new BufferedReader(new InputStreamReader(new
FileInputStream(inputFile), charset));
+ Writer trainingWriter = new OutputStreamWriter(new
FileOutputStream(trainingOutputFile), charset);
+ Writer testWriter = new OutputStreamWriter(new
FileOutputStream(testOutputFile), charset);
+ Writer writer;
+
+ int pos = 0;
+ int trainCount = 0;
+ int testCount = 0;
+
+ String line;
+ while ((line = reader.readLine()) != null) {
+ pos++;
+
+ if (testRandomSelectionPct > 0) { // Randomly choose
+ writer = randomSel.get(pos) ? testWriter : trainingWriter;
+ }
+ else { // Choose based on location
+ writer = pos > testSplitStart ? testWriter : trainingWriter;
+ }
+
+ if (writer == testWriter) {
+ if (testCount >= testSplitSize)
+ writer = trainingWriter;
+ else
+ testCount++;
+ }
+
+ if (writer == trainingWriter) {
+ trainCount++;
+ }
+
+ writer.write(line);
+ writer.write('\n');
+ }
+
+ IOUtils.quietClose(trainingWriter);
+ IOUtils.quietClose(testWriter);
+
+ log.info("file: " + inputFile.getName()+ ", input: " + lineCount + "
train: " + trainCount + ", test: " +
+ testCount + " starting at " + testSplitStart);
+
+ // testing;
+ if (callback != null) {
+ callback.splitComplete(inputFile, lineCount, trainCount, testCount,
testSplitStart);
+ }
+ }
+
+ public int getTestSplitSize() {
+ return testSplitSize;
+ }
+
+ public void setTestSplitSize(int testSplitSize) {
+ this.testSplitSize = testSplitSize;
+ }
+
+ public int getTestSplitPct() {
+ return testSplitPct;
+ }
+
+ /** Sets the percentage of the input data to allocate to the test split
+ *
+ * @param testSplitPct
+ * a value between 0 and 100 inclusive.
+ */
+ public void setTestSplitPct(int testSplitPct) {
+ this.testSplitPct = testSplitPct;
+ }
+
+ public int getSplitLocation() {
+ return splitLocation;
+ }
+
+ /** Set the location of the start of the test/training data split. Expressed
as percentage of lines, for example
+ * 0 indicates that the test data should be taken from the start of the
file, 100 indicates that the test data
+ * should be taken from the end of the input file, while 25 indicates that
the test data should be taken from the
+ * first quarter of the file.
+ * <p>
+ * This option is only relevant in cases where random selection is not
employed
+ *
+ * @param splitLocation
+ * a value between 0 and 100 inclusive.
+ */
+ public void setSplitLocation(int splitLocation) {
+ this.splitLocation = splitLocation;
+ }
+
+ public Charset getCharset() {
+ return charset;
+ }
+
+ /** Set the charset used to read and write files
+ *
+ * @param charset
+ */
+ public void setCharset(Charset charset) {
+ this.charset = charset;
+ }
+
+ public File getInputDirectory() {
+ return inputDirectory;
+ }
+
+ /** Set the directory from which input data will be read when the the
{...@link #splitDirectory()} method is invoked
+ *
+ * @param inputDir
+ */
+ public void setInputDirectory(File inputDir) {
+ this.inputDirectory = inputDir;
+ }
+
+ public File getTrainingOutputDirectory() {
+ return trainingOutputDirectory;
+ }
+
+ /** Set the directory to which training data will be written.
+ *
+ * @param trainingOutputDir
+ */
+ public void setTrainingOutputDirectory(File trainingOutputDir) {
+ this.trainingOutputDirectory = trainingOutputDir;
+ }
+
+ public File getTestOutputDirectory() {
+ return testOutputDirectory;
+ }
+
+ /** Set the directory to which test data will be written.
+ *
+ * @param testOutputDir
+ */
+ public void setTestOutputDirectory(File testOutputDir) {
+ this.testOutputDirectory = testOutputDir;
+ }
+
+ public SplitCallback getCallback() {
+ return callback;
+ }
+
+ /** Sets the callback used to inform the caller that an input file has been
successfully split
+ *
+ * @param callback
+ */
+ public void setCallback(SplitCallback callback) {
+ this.callback = callback;
+ }
+
+ public int getTestRandomSelectionSize() {
+ return testRandomSelectionSize;
+ }
+
+ /** Sets number of random input samples that will be saved to the test set.
+ *
+ * @param testRandomSelectionSize
+ */
+ public void setTestRandomSelectionSize(int testRandomSelectionSize) {
+ this.testRandomSelectionSize = testRandomSelectionSize;
+ }
+
+ public int getTestRandomSelectionPct() {
+
+ return testRandomSelectionPct;
+ }
+
+ /** Sets number of random input samples that will be saved to the test set
as a percentage of the size of the
+ * input set.
+ *
+ * @param testRandomSelectionPct
+ * a value between 0 and 100 inclusive.
+ */
+ public void setTestRandomSelectionPct(int randomSelectionPct) {
+ this.testRandomSelectionPct = randomSelectionPct;
+ }
+
+ /** Validates that the current instance is in a consistent state
+ *
+ * @throws IllegalStateException
+ * if settings violate class invariants.
+ *
+ * @throws IOException
+ * if output directories do not exist or are not directories.
+ *
+ * @throws NullPointerException
+ * if output directories are not set.
+ */
+ public void validate() throws IOException {
+ if ((testSplitSize < 1) && (testSplitSize != -1))
+ throw new IllegalStateException("test split size must be 1 or greater");
+
+ if ((splitLocation < 0 || splitLocation > 100) && (splitLocation != -1))
+ throw new IllegalStateException("test split percentage must be between 0
and 100");
+
+ if ((testSplitPct < 0 || testSplitPct > 100) && (testSplitPct != -1))
+ throw new IllegalStateException("test split percentage must be between 0
and 100");
+
+ if ((splitLocation < 0 || splitLocation > 100) && (splitLocation != -1))
+ throw new IllegalStateException("test split percentage must be between 0
and 100");
+
+ if ((testRandomSelectionPct < 0 || testRandomSelectionPct > 100) &&
(testRandomSelectionPct != -1))
+ throw new IllegalStateException("test split percentage must be between 0
and 100");
+
+ // only one of the following may be set, one must be set.
+ int count = 0;
+ if (testSplitSize > 0) count++;
+ if (testSplitPct > 0) count++;
+ if (testRandomSelectionSize > 0) count++;
+ if (testRandomSelectionPct > 0) count++;
+
+ if (count == 0) {
+ throw new IllegalStateException("either test split size, test split pct,
" +
+ "random selection size or random selection pct must be specified and a
positive integer");
+ }
+ else if (count > 1) {
+ throw new IllegalStateException("only test split size, test split pct, "
+
+ "random selection size or random selection pct may be specified");
+ }
+
+ if (trainingOutputDirectory == null)
+ throw new NullPointerException("no training output directory was
specified");
+
+ if (testOutputDirectory == null)
+ throw new NullPointerException("no test output directory was specified");
+
+ if (!trainingOutputDirectory.isDirectory())
+ throw new IOException(inputDirectory + " does not exist, or is not a
directory");
+
+ if (!testOutputDirectory.isDirectory())
+ throw new IOException(inputDirectory + " does not exist, or is not a
directory");
+ }
+
+ /** Count the lines in the file specified as returned by
<code>BufferedReader.readLine()</code>
+ *
+ * @param inputFile
+ * the file whose lines will be counted
+ *
+ * @param charset
+ * the charset of the file to read
+ *
+ * @return the number of lines in the input file.
+ *
+ * @throws IOException
+ * if there is a problem opening or reading the file.
+ */
+ public static int countLines(File inputFile, Charset charset) throws
IOException {
+ BufferedReader countReader = new BufferedReader(new InputStreamReader(new
FileInputStream(inputFile), charset));
+ int lineCount = 0;
+ while (countReader.readLine() != null) lineCount++;
+ IOUtils.quietClose(countReader);
+
+ return lineCount;
+ }
+
+ /** Used to pass information back to a caller once a file has been split
without the need for a data object */
+ public static interface SplitCallback {
+ public void splitComplete(File inputFile, int lineCount, int trainCount,
int testCount, int testSplitStart);
+ }
+}
Added:
mahout/trunk/examples/src/test/java/org/apache/mahout/classifier/bayes/SplitBayesInputTest.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/examples/src/test/java/org/apache/mahout/classifier/bayes/SplitBayesInputTest.java?rev=985413&view=auto
==============================================================================
---
mahout/trunk/examples/src/test/java/org/apache/mahout/classifier/bayes/SplitBayesInputTest.java
(added)
+++
mahout/trunk/examples/src/test/java/org/apache/mahout/classifier/bayes/SplitBayesInputTest.java
Sat Aug 14 00:30:47 2010
@@ -0,0 +1,240 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.bayes;
+
+import java.io.BufferedWriter;
+import java.io.File;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.io.OutputStreamWriter;
+import java.io.Writer;
+import java.nio.charset.Charset;
+
+import junit.framework.TestCase;
+
+import org.apache.mahout.classifier.ClassifierData;
+import org.apache.mahout.common.IOUtils;
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.math.map.OpenObjectIntHashMap;
+
+
+/**
+ * Tests for {@link SplitBayesInput}. Input is generated from <code>ClassifierData.DATA</code>, either as
+ * a single file or as one file per label, and each test checks the number of lines that end up in the
+ * training and test output files.
+ */
+public class SplitBayesInputTest extends MahoutTestCase {
+
+  /** Number of lines written per label by {@link #writeMultipleInputFiles()}. */
+  private OpenObjectIntHashMap<String> countMap;
+
+  private Charset charset;
+  private File tempInputFile;
+  private File tempTrainingDirectory;
+  private File tempTestDirectory;
+  private File tempInputDirectory;
+
+  private SplitBayesInput si;
+
+  @Override
+  public void setUp() throws Exception {
+    super.setUp();
+
+    countMap = new OpenObjectIntHashMap<String>();
+
+    charset = Charset.forName("UTF-8");
+    tempInputFile = getTestTempFile("bayesinputfile");
+    tempTrainingDirectory = getTestTempDir("bayestrain");
+    tempTestDirectory = getTestTempDir("bayestest");
+    tempInputDirectory = getTestTempDir("bayesinputdir");
+
+    si = new SplitBayesInput();
+    si.setTrainingOutputDirectory(tempTrainingDirectory);
+    si.setTestOutputDirectory(tempTestDirectory);
+    si.setInputDirectory(tempInputDirectory);
+  }
+
+  /**
+   * Writes one file per label into the temporary input directory and records the number of lines
+   * written for each label in {@link #countMap}. A new file is started whenever the label changes
+   * from one entry to the next.
+   */
+  public void writeMultipleInputFiles() throws IOException {
+    Writer writer = null;
+    String currentLabel = null;
+
+    try {
+      for (String[] entry : ClassifierData.DATA) {
+        if (!entry[0].equals(currentLabel)) {
+          currentLabel = entry[0];
+          if (writer != null) {
+            writer.close();
+          }
+          writer = new BufferedWriter(
+              new OutputStreamWriter(
+                  new FileOutputStream(new File(tempInputDirectory, currentLabel)), charset));
+        }
+        countMap.adjustOrPutValue(currentLabel, 1, 1);
+        writer.write(currentLabel + '\t' + entry[1] + '\n');
+      }
+    } finally {
+      if (writer != null) {
+        IOUtils.quietClose(writer);
+      }
+    }
+  }
+
+  /** Writes every entry of <code>ClassifierData.DATA</code> into the single temporary input file. */
+  public void writeSingleInputFile() throws IOException {
+    BufferedWriter writer = new BufferedWriter(
+        new OutputStreamWriter(new FileOutputStream(tempInputFile), charset));
+    try {
+      for (String[] entry : ClassifierData.DATA) {
+        writer.write(entry[0] + '\t' + entry[1] + '\n');
+      }
+    } finally {
+      writer.close();
+    }
+  }
+
+  public void testSplitDirectory() throws Exception {
+    final int testSplitSize = 1;
+
+    writeMultipleInputFiles();
+
+    si.setTestSplitSize(testSplitSize);
+    si.setCallback(new SplitBayesInput.SplitCallback() {
+      @Override
+      public void splitComplete(File inputFile, int lineCount, int trainCount, int testCount,
+                                int testSplitStart) {
+        int trainingLines = countMap.get(inputFile.getName()) - testSplitSize;
+        try {
+          assertSplit(inputFile, charset, testSplitSize, trainingLines, tempTrainingDirectory,
+              tempTestDirectory);
+        } catch (Exception e) {
+          throw new RuntimeException(e);
+        }
+      }
+    });
+
+    si.splitDirectory(tempInputDirectory);
+  }
+
+  // The expected test/training counts below always add up to 12, presumably the number of entries in
+  // ClassifierData.DATA.
+
+  public void testSplitFile() throws Exception {
+    writeSingleInputFile();
+    si.setTestSplitSize(2);
+    si.setCallback(new TestCallback(2, 10));
+    si.splitFile(tempInputFile);
+  }
+
+  public void testSplitFileLocation() throws Exception {
+    writeSingleInputFile();
+    si.setTestSplitSize(2);
+    si.setSplitLocation(50);
+    si.setCallback(new TestCallback(2, 10));
+    si.splitFile(tempInputFile);
+  }
+
+  public void testSplitFilePct() throws Exception {
+    writeSingleInputFile();
+    si.setTestSplitPct(25);
+
+    si.setCallback(new TestCallback(3, 9));
+    si.splitFile(tempInputFile);
+  }
+
+  public void testSplitFilePctLocation() throws Exception {
+    writeSingleInputFile();
+    si.setTestSplitPct(25);
+    si.setSplitLocation(50);
+    si.setCallback(new TestCallback(3, 9));
+    si.splitFile(tempInputFile);
+  }
+
+  public void testSplitFileRandomSelectionSize() throws Exception {
+    writeSingleInputFile();
+    si.setTestRandomSelectionSize(5);
+
+    si.setCallback(new TestCallback(5, 7));
+    si.splitFile(tempInputFile);
+  }
+
+  public void testSplitFileRandomSelectionPct() throws Exception {
+    writeSingleInputFile();
+    si.setTestRandomSelectionPct(25);
+
+    si.setCallback(new TestCallback(3, 9));
+    si.splitFile(tempInputFile);
+  }
+
+  public void testValidate() throws Exception {
+    // nothing configured at all
+    SplitBayesInput st = new SplitBayesInput();
+    assertValidateException(st, IllegalStateException.class);
+
+    // a size is set but neither output directory
+    st.setTestSplitSize(100);
+    assertValidateException(st, NullPointerException.class);
+
+    // the training directory is still missing
+    st.setTestOutputDirectory(tempTestDirectory);
+    assertValidateException(st, NullPointerException.class);
+
+    st.setTrainingOutputDirectory(tempTrainingDirectory);
+    st.validate();
+
+    // two sizing properties at once
+    st.setTestSplitPct(50);
+    assertValidateException(st, IllegalStateException.class);
+
+    st = new SplitBayesInput();
+    st.setTestRandomSelectionPct(50);
+    st.setTestOutputDirectory(tempTestDirectory);
+    st.setTrainingOutputDirectory(tempTrainingDirectory);
+    st.validate();
+
+    st.setTestSplitPct(50);
+    assertValidateException(st, IllegalStateException.class);
+
+    st = new SplitBayesInput();
+    st.setTestRandomSelectionPct(50);
+    st.setTestOutputDirectory(tempTestDirectory);
+    st.setTrainingOutputDirectory(tempTrainingDirectory);
+    st.validate();
+
+    st.setTestSplitSize(100);
+    assertValidateException(st, IllegalStateException.class);
+  }
+
+  /** Callback that checks the output files of a split against fixed expected line counts. */
+  private class TestCallback implements SplitBayesInput.SplitCallback {
+    private final int testSplitSize;
+    private final int trainingLines;
+
+    TestCallback(int testSplitSize, int trainingLines) {
+      this.testSplitSize = testSplitSize;
+      this.trainingLines = trainingLines;
+    }
+
+    @Override
+    public void splitComplete(File inputFile, int lineCount, int trainCount, int testCount,
+                              int testSplitStart) {
+      try {
+        // check the outputs of the file that was actually split
+        assertSplit(inputFile, charset, testSplitSize, trainingLines, tempTrainingDirectory,
+            tempTestDirectory);
+      } catch (Exception e) {
+        throw new RuntimeException(e);
+      }
+    }
+  }
+
+  /**
+   * Asserts that {@link SplitBayesInput#validate()} throws an exception that is an instance of the
+   * given class. Failures raised through <code>TestCase.fail</code> are errors rather than exceptions,
+   * so they are not intercepted by the catch block.
+   */
+  private void assertValidateException(SplitBayesInput st, Class<?> clazz) {
+    try {
+      st.validate();
+      TestCase.fail("Expected validate() to throw an exception, received none");
+    } catch (Exception e) {
+      if (!clazz.isInstance(e)) {
+        TestCase.fail("Unexpected exception. Expected " + clazz.getName() + " received "
+            + e.getClass().getName() + ": " + e.getMessage());
+      }
+    }
+  }
+
+  /**
+   * Asserts that the test and training output files named after <code>tempInputFile</code> exist and
+   * contain the expected number of lines.
+   */
+  private void assertSplit(File tempInputFile, Charset charset, int testSplitSize, int trainingLines,
+                           File tempTrainingDirectory, File tempTestDirectory) throws Exception {
+
+    File testFile = new File(tempTestDirectory, tempInputFile.getName());
+    TestCase.assertTrue("test file exists", testFile.isFile());
+    TestCase.assertEquals("test line count", testSplitSize, SplitBayesInput.countLines(testFile, charset));
+
+    File trainingFile = new File(tempTrainingDirectory, tempInputFile.getName());
+    TestCase.assertTrue("training file exists", trainingFile.isFile());
+    TestCase.assertEquals("training line count", trainingLines,
+        SplitBayesInput.countLines(trainingFile, charset));
+  }
+}