Author: drew
Date: Sat Aug 14 00:30:47 2010
New Revision: 985413

URL: http://svn.apache.org/viewvc?rev=985413&view=rev
Log:
MAHOUT-451: Simple utility to split bayes input into training/test sets

Added:
    mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/bayes/SplitBayesInput.java
    mahout/trunk/examples/src/test/java/org/apache/mahout/classifier/
    mahout/trunk/examples/src/test/java/org/apache/mahout/classifier/bayes/
    mahout/trunk/examples/src/test/java/org/apache/mahout/classifier/bayes/SplitBayesInputTest.java

Added: mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/bayes/SplitBayesInput.java
URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/bayes/SplitBayesInput.java?rev=985413&view=auto
==============================================================================
--- mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/bayes/SplitBayesInput.java (added)
+++ mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/bayes/SplitBayesInput.java Sat Aug 14 00:30:47 2010
@@ -0,0 +1,597 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.bayes;
+
+import java.io.BufferedReader;
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.io.InputStreamReader;
+import java.io.OutputStreamWriter;
+import java.io.Writer;
+import java.nio.charset.Charset;
+import java.util.BitSet;
+import java.util.Date;
+
+import org.apache.commons.cli2.CommandLine;
+import org.apache.commons.cli2.Group;
+import org.apache.commons.cli2.Option;
+import org.apache.commons.cli2.OptionException;
+import org.apache.commons.cli2.builder.ArgumentBuilder;
+import org.apache.commons.cli2.builder.DefaultOptionBuilder;
+import org.apache.commons.cli2.builder.GroupBuilder;
+import org.apache.commons.cli2.commandline.Parser;
+import org.apache.mahout.common.CommandLineUtil;
+import org.apache.mahout.common.IOUtils;
+import org.apache.mahout.common.commandline.DefaultOptionCreator;
+import org.apache.mahout.math.jet.random.engine.MersenneTwister;
+import org.apache.mahout.math.jet.random.sampling.RandomSampler;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * A utility for splitting files in the input format used by the Bayes
+ * classifiers into training and test sets in order to perform 
cross-validation.
+ * This class is not strictly confined to working with the Bayes classifier
+ * input. It can be used for any input files where each line is a complete
+ * sample.
+ * <p>
+ * This class can be used to split directories of files or individual files 
into
+ * training and test sets using a number of different methods.
+ * <p>
+ * When executed via {...@link #splitDirectory(File)} or {...@link 
#splitFile(File)},
+ * the lines read from one or more, input files are written to files of the 
same
+ * name into the directories specified by the
+ * {...@link #setTestOutputDirectory(File)} and
+ * {...@link #setTrainingOutputDirectory(File)} methods.
+ * <p>
+ * The composition of the test set is determined using one of the following
+ * approaches:
+ * <ul>
+ * <li>A contiguous set of items can be chosen from the input file(s) using the
+ * {...@link #setTestSplitSize(int)} or {...@link #setTestSplitPct(int)} 
methods.
+ * {...@link #setTestSplitSize(int)} allocates a fixed number of items, while
+ * {...@link #setTestSplitPct(int)} allocates a percentage of the original 
input,
+ * rounded up to the nearest integer. {...@link #setSplitLocation(int)} is 
used to
+ * control the position in the input from which the test data is extracted and
+ * is described further below.</li>
+ * <li>A random sampling of items can be chosen from the input files(s) using
+ * the {...@link #setTestRandomSelectSize(int)} or
+ * {...@link #setTestRandomSelectionPct(int)} methods, each choosing a fixed 
test
+ * set size or percentage of the input set size as described above. The
+ * {...@link org.apache.mahout.math.jet.random.sampling.RandomSampler
+ * RandomSampler} class from <code>mahout-math</code> is used to create a 
sample
+ * of the appropriate size.</li>
+ * </ul>
+ * <p>
+ * Any one of the methods above can be used to control the size of the test 
set.
+ * If multiple methods are called, a runtime exception will be thrown at
+ * execution time.
+ * <p>
+ * The {...@link #setSplitLocation(int)} method is passed an integer from 0 to 
100
+ * (inclusive) which is translated into the position of the start of the test
+ * data within the input file.
+ * <p>
+ * Given:
+ * <ul>
+ * <li>an input file of 1500 lines</li>
+ * <li>a desired test data size of 10 percent</li>
+ * </ul>
+ * <p>
+ * <ul>
+ * <li>A split location of 0 will cause the first 150 items appearing in the
+ * input set to be written to the test set.</li>
+ * <li>A split location of 25 will cause items 375-525 to be written to the 
test
+ * set.</li>
+ * <li>A split location of 100 will cause the last 150 items in the input to be
+ * written to the test set</li>
+ * </ul>
+ * The start of the split will always be adjusted forwards in order to ensure
+ * that the desired test set size is allocated. Split location has no effect is
+ * random sampling is employed.
+ */
+public class SplitBayesInput {
+  
+  private static final Logger log = 
LoggerFactory.getLogger(SplitBayesInput.class);
+
+  private int testSplitSize = -1;
+  private int testSplitPct  = -1;
+  private int splitLocation = 100;
+  private int testRandomSelectionSize = -1;
+  private int testRandomSelectionPct = -1;
+  private Charset charset = Charset.forName("UTF-8");
+  
+  private File inputDirectory;
+  private File trainingOutputDirectory;
+  private File testOutputDirectory;
+  
+  private SplitCallback callback;
+  
+  public static void main(String[] args) throws Exception {
+    SplitBayesInput si = new SplitBayesInput();
+    if (si.parseArgs(args)) {
+      si.splitDirectory();
+    }
+  }
+  
+  /** Configure this instance based on the command-line arguments contained 
within provided array. 
+   * Calls {...@link #validate()} to ensure consistency of configuration.
+   * 
+   * @param args.
+   * @return true if the arguments were parsed successfully and execution 
should proceed.
+   * @throws Exception if there is a problem parsing the command-line 
arguments or the particular
+   *   combination would violate class invariants.
+   */
+  public boolean parseArgs(String[] args) throws Exception {
+
+    DefaultOptionBuilder obuilder = new DefaultOptionBuilder();
+    ArgumentBuilder abuilder = new ArgumentBuilder();
+    GroupBuilder gbuilder = new GroupBuilder();
+    Option helpOpt = DefaultOptionCreator.helpOption();
+    
+    Option inputDirOpt = 
obuilder.withLongName("inputDir").withRequired(true).withArgument(
+        
abuilder.withName("inputDir").withMinimum(1).withMaximum(1).create()).withDescription(
+        "The input directory").withShortName("i").create();
+    
+    Option trainingOutputDirOpt = 
obuilder.withLongName("trainingOutputDir").withRequired(true).withArgument(
+        
abuilder.withName("outputDir").withMinimum(1).withMaximum(1).create()).withDescription(
+        "The training data output directory").withShortName("tr").create();
+    
+    Option testOutputDirOpt = 
obuilder.withLongName("testOutputDir").withRequired(true).withArgument(
+        
abuilder.withName("outputDir").withMinimum(1).withMaximum(1).create()).withDescription(
+        "The test data output directory").withShortName("te").create();
+    
+    Option testSplitSizeOpt = 
obuilder.withLongName("testSplitSize").withRequired(false).withArgument(
+        
abuilder.withName("splitSize").withMinimum(1).withMaximum(1).create()).withDescription(
+        "The number of documents held back as test data for each 
category").withShortName("ss").create();
+    
+    Option testSplitPctOpt = 
obuilder.withLongName("testSplitPct").withRequired(false).withArgument(
+        
abuilder.withName("splitPct").withMinimum(1).withMaximum(1).create()).withDescription(
+        "The percentage of documents held back as test data for each 
category").withShortName("sp").create();
+    
+    Option splitLocationOpt = 
obuilder.withLongName("splitLocation").withRequired(false).withArgument(
+        
abuilder.withName("splitLoc").withMinimum(1).withMaximum(1).create()).withDescription(
+        "Location for start of test data expressed as a percentage of the 
input file size (0=start, 50=middle, 100=end")
+        .withShortName("sl").create();
+    
+    Option randomSelectionSizeOpt = 
obuilder.withLongName("randomSelectionSize").withRequired(false).withArgument(
+        
abuilder.withName("randomSize").withMinimum(1).withMaximum(1).create()).withDescription(
+        "The number of itemr to be randomly selected as test data 
").withShortName("rs").create();
+    
+    Option randomSelectionPctOpt = 
obuilder.withLongName("randomSelectionPct").withRequired(false).withArgument(
+        
abuilder.withName("randomPct").withMinimum(1).withMaximum(1).create()).withDescription(
+        "Percentage of items to be randomly selected as test data 
").withShortName("rp").create();
+    
+    Option charsetOpt = 
obuilder.withLongName("charset").withRequired(true).withArgument(
+        
abuilder.withName("charset").withMinimum(1).withMaximum(1).create()).withDescription(
+        "The name of the character encoding of the input 
files").withShortName("c").create();
+    
+    Group group = 
gbuilder.withName("Options").withOption(inputDirOpt).withOption(trainingOutputDirOpt)
+         
.withOption(testOutputDirOpt).withOption(testSplitSizeOpt).withOption(testSplitPctOpt)
+         
.withOption(splitLocationOpt).withOption(randomSelectionSizeOpt).withOption(randomSelectionPctOpt)
+         .withOption(charsetOpt).create();
+    
+    try {
+      
+      Parser parser = new Parser();
+      parser.setGroup(group);
+      CommandLine cmdLine = parser.parse(args);
+      
+      if (cmdLine.hasOption(helpOpt)) {
+        CommandLineUtil.printHelp(group);
+        return false;
+      }
+      
+      inputDirectory = new File((String) cmdLine.getValue(inputDirOpt));
+      trainingOutputDirectory = new File((String) 
cmdLine.getValue(trainingOutputDirOpt));
+      testOutputDirectory = new File((String) 
cmdLine.getValue(testOutputDirOpt));
+     
+      charset = Charset.forName((String) cmdLine.getValue(charsetOpt));
+
+      if (cmdLine.hasOption(testSplitSizeOpt) && 
cmdLine.hasOption(testSplitPctOpt)) {
+        throw new OptionException(testSplitSizeOpt, "must have either split 
size or split percentage option, not BOTH");
+      }
+      else if (!cmdLine.hasOption(testSplitSizeOpt) && 
!cmdLine.hasOption(testSplitPctOpt)) {
+        throw new OptionException(testSplitSizeOpt, "must have either split 
size or split percentage option");
+      }
+
+      if (cmdLine.hasOption(testSplitSizeOpt)) {
+        setTestSplitSize(Integer.parseInt((String) 
cmdLine.getValue(testSplitSizeOpt)));
+      }
+      
+      if (cmdLine.hasOption(testSplitPctOpt)) {
+        setTestSplitPct(Integer.parseInt((String) 
cmdLine.getValue(testSplitPctOpt)));
+      }
+      
+      if (cmdLine.hasOption(splitLocationOpt)) {
+        setSplitLocation(Integer.parseInt((String) 
cmdLine.getValue(splitLocationOpt)));
+      }
+      
+      if (cmdLine.hasOption(randomSelectionSizeOpt)) {
+        setTestRandomSelectionSize(Integer.parseInt((String) 
cmdLine.getValue(randomSelectionSizeOpt)));
+      }
+      
+      if (cmdLine.hasOption(randomSelectionPctOpt)) {
+        setTestRandomSelectionPct(Integer.parseInt((String) 
cmdLine.getValue(randomSelectionPctOpt)));
+      }
+
+      trainingOutputDirectory.mkdirs();
+      testOutputDirectory.mkdirs();
+     
+    } catch (OptionException e) {
+      log.error("Command-line option Exception", e);
+      CommandLineUtil.printHelp(group);
+      return false;
+    }
+    
+    validate();
+    return true;
+  }
+  
+  /** Perform a split on directory specified by {...@link 
#setInputDirectory(File)} by calling {...@link #splitFile(File)} 
+   *  on each file found within that directory.
+   *  
+   * @param inputDir
+   * @throws IOException
+   */
+  public void splitDirectory() throws IOException {
+    this.splitDirectory(inputDirectory);
+  }
+  
+  /** Perform a split on the specified directory by calling {...@link 
#splitFile(File)} on each file found within that
+   *  directory.
+   *  
+   * @param inputDir
+   * @throws IOException
+   */
+  public void splitDirectory(File inputDir) throws IOException {
+    if (!inputDir.isDirectory()) 
+      throw new IOException(inputDir + " does not exist, or is not a 
directory");
+    
+    // input dir contains one file per category.
+    File[] inputFiles = inputDir.listFiles();
+    for (File inputFile : inputFiles) {
+      if (inputFile.isFile())
+        splitFile(inputFile);
+    }
+  }
+
+  /** Perform a split on the specified input file. Results will be written to 
files of the same name in the specified 
+   *  training and test output directories. The {...@link #validate()} method 
is called prior to executing the split.
+   *  
+   * @param inputFile
+   * @throws IOException
+   */
+  public void splitFile(File inputFile) throws IOException {
+    if (!inputFile.isFile()) 
+      throw new IOException(inputFile + " does not exist, or is not a file");
+    
+    validate();
+    
+    File testOutputFile = new File(testOutputDirectory, inputFile.getName());
+    File trainingOutputFile = new File(trainingOutputDirectory, 
inputFile.getName());
+    
+    int lineCount = countLines(inputFile, charset);
+    
+    log.info(inputFile.getName() + " has " + lineCount + " lines");
+    
+    int testSplitStart = 0;
+    int testSplitSize  = this.testSplitSize; // don't modify state
+    BitSet randomSel = null;
+    
+    if (testRandomSelectionPct > 0 || testRandomSelectionSize > 0) {
+      testSplitSize = this.testRandomSelectionSize;
+      
+      if (testRandomSelectionPct > 0) {
+        testSplitSize = Math.round(lineCount * ( testRandomSelectionPct / 
100.0f ));
+      }
+      log.info(inputFile.getName() + " test split size is " + testSplitSize + 
+          " based on random selection percentage " + testRandomSelectionPct);
+      long[] ridx = new long[testSplitSize];
+      RandomSampler.sample(testSplitSize, lineCount-1, testSplitSize, 0, ridx, 
0, new MersenneTwister(new Date()));
+      randomSel = new BitSet(lineCount);
+      for (int i=0; i < ridx.length; i++) {
+        randomSel.set((int) ridx[i] + 1);
+      }
+    }
+    else {
+      if (testSplitPct > 0) { // calculate split size based on percentage
+        testSplitSize = Math.round(lineCount * ( testSplitPct / 100.0f ));
+        log.info(inputFile.getName() + " test split size is " + testSplitSize 
+ 
+            " based on percentage " + testSplitPct);
+      }
+      else {
+        log.info(inputFile.getName() + " test split size is " + testSplitSize);
+      }
+      
+      if (splitLocation > 0) { // calculate start of split based on percentage
+        testSplitStart =  Math.round(lineCount * ( splitLocation / 100.0f ));
+        if (lineCount - testSplitStart < testSplitSize) {
+          // adjust split start downwards based on split size.
+          testSplitStart = lineCount - testSplitSize;
+        }
+        log.info(inputFile.getName() + " test split start is " + 
testSplitStart + 
+            " based on split location " + splitLocation);
+      }
+      
+      if (testSplitStart < 0) {
+        throw new IllegalArgumentException("test split size for " + inputFile 
+ " is too large, it would produce an "
+            + "empty training set from the initial set of " + lineCount + " 
examples");
+      }
+      else if ((lineCount - testSplitSize) < testSplitSize) {
+        log.warn("CAUTION: Test set size for " + inputFile + " may be too 
large, " + testSplitSize + 
+            " is larger than the number of lines remaining in the training 
set: " + (lineCount - testSplitSize));
+      }
+    }
+    BufferedReader reader = new BufferedReader(new InputStreamReader(new 
FileInputStream(inputFile), charset));
+    Writer trainingWriter = new OutputStreamWriter(new 
FileOutputStream(trainingOutputFile), charset);
+    Writer testWriter     = new OutputStreamWriter(new 
FileOutputStream(testOutputFile), charset);
+    Writer writer;
+    
+    int pos = 0;
+    int trainCount = 0;
+    int testCount = 0;
+
+    String line;
+    while ((line = reader.readLine()) != null) {
+      pos++;
+      
+      if (testRandomSelectionPct > 0) { // Randomly choose
+        writer =  randomSel.get(pos) ? testWriter : trainingWriter;
+      }
+      else { // Choose based on location
+        writer = pos > testSplitStart ? testWriter : trainingWriter;
+      }
+      
+      if (writer == testWriter) {
+        if (testCount >= testSplitSize) 
+          writer = trainingWriter;
+        else
+          testCount++;
+      }
+      
+      if (writer == trainingWriter) {
+        trainCount++;
+      }
+      
+      writer.write(line);
+      writer.write('\n');
+    }
+    
+    IOUtils.quietClose(trainingWriter);
+    IOUtils.quietClose(testWriter);
+    
+    log.info("file: " + inputFile.getName()+ ", input: " + lineCount + " 
train: " + trainCount + ", test: " + 
+        testCount + " starting at " + testSplitStart);
+    
+    // testing;
+    if (callback != null) {
+      callback.splitComplete(inputFile, lineCount, trainCount, testCount, 
testSplitStart);
+    }
+  }
+  
+  public int getTestSplitSize() {
+    return testSplitSize;
+  }
+
+  public void setTestSplitSize(int testSplitSize) {
+    this.testSplitSize = testSplitSize;
+  }
+
+  public int getTestSplitPct() {
+    return testSplitPct;
+  }
+
+  /** Sets the percentage of the input data to allocate to the test split 
+   * 
+   * @param testSplitPct 
+   *   a value between 0 and 100 inclusive.
+   */
+  public void setTestSplitPct(int testSplitPct) {
+    this.testSplitPct = testSplitPct;
+  }
+
+  public int getSplitLocation() {
+    return splitLocation;
+  }
+
+  /** Set the location of the start of the test/training data split. Expressed 
as percentage of lines, for example
+   *  0 indicates that the test data should be taken from the start of the 
file, 100 indicates that the test data 
+   *  should be taken from the end of the input file, while 25 indicates that 
the test data should be taken from the 
+   *  first quarter of the file. 
+   *  <p>
+   *  This option is only relevant in cases where random selection is not 
employed
+   * 
+   * @param splitLocation 
+   *   a value between 0 and 100 inclusive.
+   */
+  public void setSplitLocation(int splitLocation) {
+    this.splitLocation = splitLocation;
+  }
+
+  public Charset getCharset() {
+    return charset;
+  }
+
+  /** Set the charset used to read and write files 
+   * 
+   * @param charset
+   */
+  public void setCharset(Charset charset) {
+    this.charset = charset;
+  }
+
+  public File getInputDirectory() {
+    return inputDirectory;
+  }
+
+  /** Set the directory from which input data will be read when the the 
{...@link #splitDirectory()} method is invoked
+   * 
+   * @param inputDir
+   */
+  public void setInputDirectory(File inputDir) {
+    this.inputDirectory = inputDir;
+  }
+
+  public File getTrainingOutputDirectory() {
+    return trainingOutputDirectory;
+  }
+
+  /** Set the directory to which training data will be written.
+   * 
+   * @param trainingOutputDir
+   */
+  public void setTrainingOutputDirectory(File trainingOutputDir) {
+    this.trainingOutputDirectory = trainingOutputDir;
+  }
+
+  public File getTestOutputDirectory() {
+    return testOutputDirectory;
+  }
+
+  /** Set the directory to which test data will be written.
+   * 
+   * @param testOutputDir
+   */
+  public void setTestOutputDirectory(File testOutputDir) {
+    this.testOutputDirectory = testOutputDir;
+  }
+
+  public SplitCallback getCallback() {
+    return callback;
+  }
+
+  /** Sets the callback used to inform the caller that an input file has been 
successfully split
+   * 
+   * @param callback
+   */
+  public void setCallback(SplitCallback callback) {
+    this.callback = callback;
+  }
+
+  public int getTestRandomSelectionSize() {
+    return testRandomSelectionSize;
+  }
+
+  /** Sets number of random input samples that will be saved to the test set.
+   * 
+   * @param testRandomSelectionSize 
+   */
+  public void setTestRandomSelectionSize(int testRandomSelectionSize) {
+    this.testRandomSelectionSize = testRandomSelectionSize;
+  }
+
+  public int getTestRandomSelectionPct() {
+
+    return testRandomSelectionPct;
+  }
+
+  /** Sets number of random input samples that will be saved to the test set 
as a percentage of the size of the 
+   *  input set.
+   * 
+   * @param testRandomSelectionPct 
+   *   a value between 0 and 100 inclusive.
+   */
+  public void setTestRandomSelectionPct(int randomSelectionPct) {
+    this.testRandomSelectionPct = randomSelectionPct;
+  }
+
+  /** Validates that the current instance is in a consistent state
+   * 
+   * @throws IllegalStateException 
+   *   if settings violate class invariants.
+   *   
+   * @throws IOException 
+   *   if output directories do not exist or are not directories.
+   *   
+   * @throws NullPointerException 
+   *   if output directories are not set.
+   */
+  public void validate() throws IOException {
+    if ((testSplitSize < 1) && (testSplitSize != -1))
+      throw new IllegalStateException("test split size must be 1 or greater");
+    
+    if ((splitLocation < 0 || splitLocation > 100) && (splitLocation != -1))
+      throw new IllegalStateException("test split percentage must be between 0 
and 100");
+    
+    if ((testSplitPct < 0 || testSplitPct > 100) && (testSplitPct != -1))
+      throw new IllegalStateException("test split percentage must be between 0 
and 100");
+    
+    if ((splitLocation < 0 || splitLocation > 100) && (splitLocation != -1))
+      throw new IllegalStateException("test split percentage must be between 0 
and 100");
+    
+    if ((testRandomSelectionPct < 0 || testRandomSelectionPct > 100) && 
(testRandomSelectionPct != -1))
+      throw new IllegalStateException("test split percentage must be between 0 
and 100");
+
+    // only one of the following may be set, one must be set.
+    int count = 0;
+    if (testSplitSize > 0) count++;
+    if (testSplitPct  > 0) count++;
+    if (testRandomSelectionSize > 0) count++;
+    if (testRandomSelectionPct > 0) count++;
+    
+    if (count == 0) {
+      throw new IllegalStateException("either test split size, test split pct, 
" +
+        "random selection size or random selection pct must be specified and a 
positive integer");
+    }
+    else if (count > 1) {
+      throw new IllegalStateException("only test split size, test split pct, " 
+
+        "random selection size or random selection pct may be specified");
+    }
+
+    if (trainingOutputDirectory == null)
+      throw new NullPointerException("no training output directory was 
specified");
+   
+    if (testOutputDirectory == null)
+      throw new NullPointerException("no test output directory was specified");
+    
+    if (!trainingOutputDirectory.isDirectory())
+      throw new IOException(inputDirectory + " does not exist, or is not a 
directory");
+    
+    if (!testOutputDirectory.isDirectory()) 
+      throw new IOException(inputDirectory + " does not exist, or is not a 
directory");
+  }
+  
+  /** Count the lines in the file specified as returned by 
<code>BufferedReader.readLine()</code>
+   * 
+   * @param inputFile 
+   *   the file whose lines will be counted
+   *   
+   * @param charset
+   *   the charset of the file to read
+   *   
+   * @return the number of lines in the input file.
+   * 
+   * @throws IOException 
+   *   if there is a problem opening or reading the file.
+   */
+  public static int countLines(File inputFile, Charset charset) throws 
IOException {
+    BufferedReader countReader = new BufferedReader(new InputStreamReader(new 
FileInputStream(inputFile), charset));
+    int lineCount = 0;
+    while (countReader.readLine() != null) lineCount++;
+    IOUtils.quietClose(countReader);
+    
+    return lineCount;
+  }
+  
+  /** Used to pass information back to a caller once a file has been split 
without the need for a data object */
+  public static interface SplitCallback {
+    public void splitComplete(File inputFile, int lineCount, int trainCount, 
int testCount, int testSplitStart);
+  }
+}

Added: mahout/trunk/examples/src/test/java/org/apache/mahout/classifier/bayes/SplitBayesInputTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/test/java/org/apache/mahout/classifier/bayes/SplitBayesInputTest.java?rev=985413&view=auto
==============================================================================
--- mahout/trunk/examples/src/test/java/org/apache/mahout/classifier/bayes/SplitBayesInputTest.java (added)
+++ mahout/trunk/examples/src/test/java/org/apache/mahout/classifier/bayes/SplitBayesInputTest.java Sat Aug 14 00:30:47 2010
@@ -0,0 +1,240 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.bayes;
+
+import java.io.BufferedWriter;
+import java.io.File;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.io.OutputStreamWriter;
+import java.io.Writer;
+import java.nio.charset.Charset;
+
+import junit.framework.TestCase;
+
+import org.apache.mahout.classifier.ClassifierData;
+import org.apache.mahout.common.IOUtils;
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.math.map.OpenObjectIntHashMap;
+
+
+public class SplitBayesInputTest extends MahoutTestCase {
+
+  OpenObjectIntHashMap<String> countMap;
+  
+  Charset charset;
+  File tempInputFile;
+  File tempTrainingDirectory;
+  File tempTestDirectory;
+  File tempInputDirectory;
+  
+  SplitBayesInput si;
+    
+  public void setUp() throws Exception {
+    super.setUp();
+  
+    countMap = new OpenObjectIntHashMap<String>();
+    
+    charset = Charset.forName("UTF-8");
+    tempInputFile = getTestTempFile("bayesinputfile");
+    tempTrainingDirectory = getTestTempDir("bayestrain");
+    tempTestDirectory = getTestTempDir("bayestest");
+    tempInputDirectory = getTestTempDir("bayesinputdir");
+    
+    si = new SplitBayesInput();
+    si.setTrainingOutputDirectory(tempTrainingDirectory);
+    si.setTestOutputDirectory(tempTestDirectory);
+    si.setInputDirectory(tempInputDirectory);
+  }
+  
+  public void writeMultipleInputFiles() throws IOException {
+    Writer writer = null;
+    String currentLabel = null;
+    
+    for (String[] entry : ClassifierData.DATA) {
+      if (!entry[0].equals(currentLabel)) {
+        currentLabel = entry[0];
+        if (writer != null) IOUtils.quietClose(writer);
+        writer = new BufferedWriter(
+            new OutputStreamWriter(
+                new FileOutputStream(new File(tempInputDirectory, 
currentLabel)), Charset.forName("UTF-8")));
+      }
+      countMap.adjustOrPutValue(currentLabel, 1, 1);
+      writer.write(currentLabel + '\t' + entry[1] + '\n');
+    }
+    IOUtils.quietClose(writer);
+  }
+
+  public void writeSingleInputFile() throws IOException {
+    BufferedWriter writer = new BufferedWriter(
+        new OutputStreamWriter(new FileOutputStream(tempInputFile), 
Charset.forName("UTF-8")));
+    for (String[] entry : ClassifierData.DATA) {
+      writer.write(entry[0] + '\t' + entry[1] + '\n');
+    }
+    writer.close();
+  }
+  
+  public void testSplitDirectory() throws Exception {
+    final int testSplitSize = 1;
+    
+    writeMultipleInputFiles();
+    
+    si.setTestSplitSize(testSplitSize);
+    si.setCallback(new SplitBayesInput.SplitCallback() {
+          @Override
+          public void splitComplete(File inputFile, int lineCount, int 
trainCount, int testCount, int testSplitStart) {
+            int trainingLines = countMap.get(inputFile.getName()) - 
testSplitSize;
+            try {
+              assertSplit(inputFile, charset, testSplitSize, trainingLines, 
tempTrainingDirectory, tempTestDirectory);
+            }
+            catch (Exception e) {
+              throw new RuntimeException(e);
+            }
+          }
+    });
+    
+    si.splitDirectory(tempInputDirectory);
+  }
+  
+  public void testSplitFile() throws Exception {
+    writeSingleInputFile();
+    si.setTestSplitSize(2);
+    si.setCallback(new TestCallback(2, 10));
+    si.splitFile(tempInputFile);
+  }
+  
+  public void testSplitFileLocation() throws Exception {
+    writeSingleInputFile();
+    si.setTestSplitSize(2);
+    si.setSplitLocation(50);
+    si.setCallback(new TestCallback(2, 10));
+    si.splitFile(tempInputFile);
+  }
+  
+  public void testSplitFilePct() throws Exception {
+    writeSingleInputFile();
+    si.setTestSplitPct(25);
+   
+    si.setCallback(new TestCallback(3, 9));
+    si.splitFile(tempInputFile);
+  }
+  
+  public void testSplitFilePctLocation() throws Exception {
+    writeSingleInputFile();
+    si.setTestSplitPct(25);
+    si.setSplitLocation(50);
+    si.setCallback(new TestCallback(3, 9));
+    si.splitFile(tempInputFile);
+  }
+  
+  public void testSplitFileRandomSelectionSize() throws Exception {
+    writeSingleInputFile();
+    si.setTestRandomSelectionSize(5);
+   
+    si.setCallback(new TestCallback(5, 7));
+    si.splitFile(tempInputFile);
+  }
+  
+  public void testSplitFileRandomSelectionPct() throws Exception {
+    writeSingleInputFile();
+    si.setTestRandomSelectionPct(25);
+   
+    si.setCallback(new TestCallback(3, 9));
+    si.splitFile(tempInputFile);
+  }
+  
+  public void testValidate() throws Exception {
+    SplitBayesInput st = new SplitBayesInput();
+    assertValidateException(st, IllegalStateException.class);
+    
+    st.setTestSplitSize(100);
+    assertValidateException(st, NullPointerException.class); 
+    
+    st.setTestOutputDirectory(tempTestDirectory);
+    assertValidateException(st, NullPointerException.class); 
+    
+    st.setTrainingOutputDirectory(tempTrainingDirectory);
+    st.validate();
+    
+    st.setTestSplitPct(50);
+    assertValidateException(st, IllegalStateException.class);
+    
+    st = new SplitBayesInput();
+    st.setTestRandomSelectionPct(50);
+    st.setTestOutputDirectory(tempTestDirectory);
+    st.setTrainingOutputDirectory(tempTrainingDirectory);
+    st.validate();
+    
+    st.setTestSplitPct(50);
+    assertValidateException(st, IllegalStateException.class);
+    
+    st = new SplitBayesInput();
+    st.setTestRandomSelectionPct(50);
+    st.setTestOutputDirectory(tempTestDirectory);
+    st.setTrainingOutputDirectory(tempTrainingDirectory);
+    st.validate();
+    
+    st.setTestSplitSize(100);
+    assertValidateException(st, IllegalStateException.class);
+  }
+  
+  private class TestCallback implements SplitBayesInput.SplitCallback {
+    int testSplitSize;
+    int trainingLines;
+    
+    public TestCallback(int testSplitSize, int trainingLines) {
+      this.testSplitSize = testSplitSize;
+      this.trainingLines = trainingLines;
+    }
+    
+    @Override
+    public void splitComplete(File inputFile, int lineCount, int trainCount, 
int testCount, int testSplitStart) {
+      try {
+        assertSplit(tempInputFile, charset, testSplitSize, trainingLines, 
tempTrainingDirectory, tempTestDirectory);
+      }
+      catch (Exception e) {
+        throw new RuntimeException(e);
+      }
+    }
+  }
+  
+  private void assertValidateException(SplitBayesInput st, Class<?> clazz) {
+    try {
+      st.validate();
+      TestCase.fail("Expected valdate() to throw an exception, received none");
+    }
+    catch (Exception e) {
+      if (!e.getClass().isAssignableFrom(clazz)) {
+        e.printStackTrace();
+        TestCase.fail("Unexpected exception. Expected " + clazz.getName() + " 
received " + e.getClass().getName());
+      }
+    } 
+  }
+  
+  private void assertSplit(File tempInputFile, Charset charset, int 
testSplitSize, int trainingLines, 
+      File tempTrainingDirectory, File tempTestDirectory) throws Exception {
+    
+    File testFile = new File(tempTestDirectory, tempInputFile.getName());
+    TestCase.assertTrue("test file exists", testFile.isFile());
+    TestCase.assertEquals("test line count", testSplitSize, 
SplitBayesInput.countLines(testFile,charset));
+    
+    File trainingFile = new File(tempTrainingDirectory, 
tempInputFile.getName());
+    TestCase.assertTrue("training file exists", trainingFile.isFile());
+    TestCase.assertEquals("training line count", trainingLines, 
SplitBayesInput.countLines(trainingFile, charset));  
+  }
+}


Reply via email to