Author: adeneche
Date: Sat Mar 13 16:25:13 2010
New Revision: 922594

URL: http://svn.apache.org/viewvc?rev=922594&view=rev
Log:
MAHOUT-323: TestForest can now store its predictions in a file (--output), and it only analyzes the classification results when requested (--analyze)
Modified: lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapreduce/TestForest.java Modified: lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapreduce/TestForest.java URL: http://svn.apache.org/viewvc/lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapreduce/TestForest.java?rev=922594&r1=922593&r2=922594&view=diff ============================================================================== --- lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapreduce/TestForest.java (original) +++ lucene/mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapreduce/TestForest.java Sat Mar 13 16:25:13 2010 @@ -35,6 +35,7 @@ import org.apache.hadoop.conf.Configured import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; import org.apache.hadoop.fs.FSDataInputStream; +import org.apache.hadoop.fs.FSDataOutputStream; import org.apache.hadoop.util.Tool; import org.apache.hadoop.util.ToolRunner; import org.apache.mahout.common.CommandLineUtil; @@ -65,6 +66,10 @@ public class TestForest extends Configur private Path modelPath; // path where the forest is stored + private Path outputPath; // path to predictions file, if null do not output the predictions + + private boolean analyze; // analyze the classification results ? + @Override public int run(String[] args) throws IOException, ClassNotFoundException, InterruptedException { @@ -83,11 +88,17 @@ public class TestForest extends Configur abuilder.withName("path").withMinimum(1).withMaximum(1).create()). 
withDescription("Path to the Decision Forest").create(); + Option outputOpt = obuilder.withLongName("output").withShortName("o").withRequired(false).withArgument( + abuilder.withName("output").withMinimum(1).withMaximum(1).create()).withDescription( + "Path to generated predictions file").create(); + + Option analyzeOpt = obuilder.withLongName("analyze").withShortName("a").withRequired(false).create(); + Option helpOpt = obuilder.withLongName("help").withDescription("Print out help").withShortName("h") .create(); Group group = gbuilder.withName("Options").withOption(inputOpt).withOption(datasetOpt) - .withOption(modelOpt).withOption(helpOpt).create(); + .withOption(modelOpt).withOption(outputOpt).withOption(analyzeOpt).withOption(helpOpt).create(); try { Parser parser = new Parser(); @@ -102,15 +113,21 @@ public class TestForest extends Configur String dataName = cmdLine.getValue(inputOpt).toString(); String datasetName = cmdLine.getValue(datasetOpt).toString(); String modelName = cmdLine.getValue(modelOpt).toString(); + String outputName = (cmdLine.hasOption(outputOpt)) ? 
cmdLine.getValue(outputOpt).toString() : null; + analyze = cmdLine.hasOption(analyzeOpt); - log.debug("inout : {}", dataName); + log.debug("inout : {}", dataName); log.debug("dataset : {}", datasetName); - log.debug("model : {}", modelName); + log.debug("model : {}", modelName); + log.debug("output : {}", outputName); + log.debug("analyze : {}", analyze); dataPath = new Path(dataName); datasetPath = new Path(datasetName); modelPath = new Path(modelName); - + if (outputName != null) { + outputPath = new Path(outputName); + } } catch (OptionException e) { System.err.println("Exception : " + e); CommandLineUtil.printHelp(group); @@ -123,6 +140,17 @@ public class TestForest extends Configur } private void testForest() throws IOException, ClassNotFoundException, InterruptedException { + + FileSystem ofs = null; + + // make sure the output file does not exist + if (outputPath != null) { + ofs = outputPath.getFileSystem(getConf()); + if (ofs.exists(outputPath)) { + throw new IllegalArgumentException("Output path already exists"); + } + } + Dataset dataset = Dataset.load(getConf(), datasetPath); DataConverter converter = new DataConverter(dataset); @@ -146,6 +174,9 @@ public class TestForest extends Configur return; } + // create the predictions file + FSDataOutputStream ofile = (outputPath != null) ? ofs.create(outputPath) : null; + log.info("Sequential classification..."); long time = System.currentTimeMillis(); @@ -153,7 +184,7 @@ public class TestForest extends Configur FSDataInputStream input = tfs.open(dataPath); Scanner scanner = new Scanner(input); Random rng = RandomUtils.getRandom(); - ResultAnalyzer analyzer = new ResultAnalyzer(Arrays.asList(dataset.labels()), "unknown"); + ResultAnalyzer analyzer = (analyze) ? 
new ResultAnalyzer(Arrays.asList(dataset.labels()), "unknown") : null; while (scanner.hasNextLine()) { String line = scanner.nextLine(); @@ -164,13 +195,22 @@ public class TestForest extends Configur Instance instance = converter.convert(0, line); int prediction = forest.classify(rng, instance); - analyzer.addInstance(dataset.getLabel(instance.label), new ClassifierResult(dataset.getLabel(prediction), 1.0)); + if (outputPath != null) { + ofile.writeChars(Integer.toString(prediction)); // write the prediction + ofile.writeChar('\n'); + } + + if (analyze) { + analyzer.addInstance(dataset.getLabel(instance.label), new ClassifierResult(dataset.getLabel(prediction), 1.0)); + } } time = System.currentTimeMillis() - time; log.info("Classification Time: {}", DFUtils.elapsedTime(time)); - log.info(analyzer.summarize()); + if (analyze) { + log.info(analyzer.summarize()); + } } /**