Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/df/split/Split.java URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/df/split/Split.java?rev=1213034&r1=1213033&r2=1213034&view=diff ============================================================================== --- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/df/split/Split.java (original) +++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/df/split/Split.java Sun Dec 11 17:53:50 2011 @@ -38,17 +38,23 @@ public final class Split { this(attr, ig, Double.NaN); } - /** attribute to split for */ + /** + * @return attribute to split for + */ public int getAttr() { return attr; } - /** Information Gain of the split */ + /** + * @return Information Gain of the split + */ public double getIg() { return ig; } - /** split value for NUMERICAL attributes */ + /** + * @return split value for NUMERICAL attributes + */ public double getSplit() { return split; }
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/df/tools/Frequencies.java URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/df/tools/Frequencies.java?rev=1213034&r1=1213033&r2=1213034&view=diff ============================================================================== --- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/df/tools/Frequencies.java (original) +++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/df/tools/Frequencies.java Sun Dec 11 17:53:50 2011 @@ -39,7 +39,8 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** - * Compute the frequency distribution of the "class label" + * Compute the frequency distribution of the "class label"<br> + * This class can be used when the criterion variable is the categorical attribute. */ public final class Frequencies extends Configured implements Tool { Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/df/tools/FrequenciesJob.java URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/df/tools/FrequenciesJob.java?rev=1213034&r1=1213033&r2=1213034&view=diff ============================================================================== --- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/df/tools/FrequenciesJob.java (original) +++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/df/tools/FrequenciesJob.java Sun Dec 11 17:53:50 2011 @@ -50,7 +50,8 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; /** - * Temporary class used to compute the frequency distribution of the "class attribute". + * Temporary class used to compute the frequency distribution of the "class attribute".<br> + * This class can be used when the criterion variable is the categorical attribute. */ public class FrequenciesJob { @@ -124,7 +125,7 @@ public class FrequenciesJob { * * @return counts[partition][label] = num tuples from 'partition' with class == label */ - protected int[][] parseOutput(JobContext job) throws IOException { + int[][] parseOutput(JobContext job) throws IOException { Configuration conf = job.getConfiguration(); int numMaps = conf.getInt("mapred.map.tasks", -1); @@ -176,7 +177,7 @@ public class FrequenciesJob { /** * Useful when testing */ - protected void setup(Dataset dataset) { + void setup(Dataset dataset) { converter = new DataConverter(dataset); } @@ -189,7 +190,7 @@ public class FrequenciesJob { Instance instance = converter.convert(value.toString()); - context.write(firstId, new IntWritable(dataset.getLabel(instance))); + context.write(firstId, new IntWritable((int) dataset.getLabel(instance))); } } @@ -208,7 +209,7 @@ public class FrequenciesJob { /** * Useful when testing */ - protected void setup(int nblabels) { + void setup(int nblabels) { this.nblabels = nblabels; } @@ -236,7 +237,9 @@ public class FrequenciesJob { /** counts[c] = num tuples from the partition with label == c */ private int[] counts; - protected Frequencies(long firstId, int[] counts) { + public Frequencies() { } + + Frequencies(long firstId, int[] counts) { this.firstId = firstId; this.counts = Arrays.copyOf(counts, counts.length); } Modified: mahout/trunk/core/src/main/java/org/apache/mahout/classifier/df/tools/UDistrib.java URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/df/tools/UDistrib.java?rev=1213034&r1=1213033&r2=1213034&view=diff ============================================================================== --- mahout/trunk/core/src/main/java/org/apache/mahout/classifier/df/tools/UDistrib.java (original) +++ mahout/trunk/core/src/main/java/org/apache/mahout/classifier/df/tools/UDistrib.java Sun Dec 11 17:53:50 2011 @@ -50,7 +50,8 @@ import com.google.common.base.Preconditi /** * This tool is used to uniformly distribute the class of all the tuples of the dataset over a given number of - * partitions. + * partitions.<br> + * This class can be used when the criterion variable is the categorical attribute. */ public final class UDistrib { @@ -63,7 +64,8 @@ public final class UDistrib { * Launch the uniform distribution tool. Requires the following command line arguments:<br> * * data : data path dataset : dataset path numpartitions : num partitions output : output path - * + * + * @throws java.io.IOException */ public static void main(String[] args) throws IOException { @@ -175,7 +177,7 @@ public final class UDistrib { // write the tuple in files[tuple.label] Instance instance = converter.convert(line); - int label = dataset.getLabel(instance); + int label = (int) dataset.getLabel(instance); files[currents[label]].writeBytes(line); files[currents[label]].writeChar('\n'); Modified: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/data/DataConverterTest.java URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/data/DataConverterTest.java?rev=1213034&r1=1213033&r2=1213034&view=diff ============================================================================== --- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/data/DataConverterTest.java (original) +++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/data/DataConverterTest.java Sun Dec 11 17:53:50 2011 @@ -44,5 +44,17 @@ public final class DataConverterTest ext for (int index = 0; index < data.size(); index++) { assertEquals(data.get(index), converter.convert(sData[index])); } + + // regression + source = Utils.randomDoubles(rng, descriptor, true, INSTANCE_COUNT); + sData = Utils.double2String(source); + dataset = DataLoader.generateDataset(descriptor, true, sData); + data = DataLoader.loadData(dataset, sData); + + converter = new DataConverter(dataset); + + for (int index = 0; index < data.size(); index++) { + assertEquals(data.get(index), converter.convert(sData[index])); + } } } Modified: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/data/DataLoaderTest.java URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/data/DataLoaderTest.java?rev=1213034&r1=1213033&r2=1213034&view=diff ============================================================================== --- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/data/DataLoaderTest.java (original) +++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/data/DataLoaderTest.java Sun Dec 11 17:53:50 2011 @@ -57,6 +57,16 @@ public final class DataLoaderTest extend testLoadedData(data, attrs, missings, loaded); testLoadedDataset(data, attrs, missings, loaded); + + // regression + data = Utils.randomDoubles(rng, descriptor, true, datasize); + missings = Lists.newArrayList(); + sData = prepareData(data, attrs, missings); + dataset = DataLoader.generateDataset(descriptor, true, sData); + loaded = DataLoader.loadData(dataset, sData); + + testLoadedData(data, attrs, missings, loaded); + testLoadedDataset(data, attrs, missings, loaded); } /** @@ -81,7 +91,17 @@ public final class DataLoaderTest extend Dataset dataset = DataLoader.generateDataset(descriptor, false, sData); assertEquals(expected, dataset); - } + + // regression + data = Utils.randomDoubles(rng, descriptor, true, datasize); + missings = Lists.newArrayList(); + sData = prepareData(data, attrs, missings); + expected = DataLoader.generateDataset(descriptor, true, sData); + + dataset = DataLoader.generateDataset(descriptor, true, sData); + + assertEquals(expected, dataset); +} /** * Converts the data to an array of comma-separated strings and adds some @@ -153,14 +173,21 @@ public final class DataLoaderTest extend } if (attrs[attr].isNumerical()) { - assertEquals(vector[attr], instance.get(aId++), EPSILON); - } else if (attrs[attr].isCategorical()||attrs[attr].isLabel()) { + assertEquals(vector[attr], instance.get(aId), EPSILON); + aId++; + } else if (attrs[attr].isCategorical()) { checkCategorical(data, missings, loaded, attr, aId, vector[attr], instance.get(aId)); aId++; - } /*else if (attrs[attr].isLabel()) { - checkLabel(data, missings, loaded, attr, vector[attr]); - }*/ + } else if (attrs[attr].isLabel()) { + if (loaded.getDataset().isNumerical(aId)) { + assertEquals(vector[attr], instance.get(aId), EPSILON); + } else { + checkCategorical(data, missings, loaded, attr, aId, vector[attr], + instance.get(aId)); + } + aId++; + } } lind++; @@ -193,14 +220,21 @@ public final class DataLoaderTest extend continue; } - assertEquals(attrs[attr].isNumerical(), loaded.getDataset().isNumerical(aId)); - - if (attrs[attr].isCategorical()) { - double nValue = instance.get(aId); - String oValue = Double.toString(data[index][attr]); - assertEquals((double) loaded.getDataset().valueOf(aId, oValue), nValue, EPSILON); + if (attrs[attr].isLabel()) { + if (!loaded.getDataset().isNumerical(aId)) { + double nValue = instance.get(aId); + String oValue = Double.toString(data[index][attr]); + assertEquals((double) loaded.getDataset().valueOf(aId, oValue), nValue, EPSILON); + } + } else { + assertEquals(attrs[attr].isNumerical(), loaded.getDataset().isNumerical(aId)); + + if (attrs[attr].isCategorical()) { + double nValue = instance.get(aId); + String oValue = Double.toString(data[index][attr]); + assertEquals((double) loaded.getDataset().valueOf(aId, oValue), nValue, EPSILON); + } } - aId++; } } @@ -227,7 +261,19 @@ public final class DataLoaderTest extend Data loaded = DataLoader.loadData(dataset, fs, dataPath); testLoadedData(source, attrs, missings, loaded); - } + + // regression + source = Utils.randomDoubles(rng, descriptor, true, datasize); + missings = Lists.newArrayList(); + sData = prepareData(source, attrs, missings); + dataset = DataLoader.generateDataset(descriptor, true, sData); + + dataPath = Utils.writeDataToTestFile(sData); + fs = dataPath.getFileSystem(new Configuration()); + loaded = DataLoader.loadData(dataset, fs, dataPath); + + testLoadedData(source, attrs, missings, loaded); +} /** * Test method for @@ -254,6 +300,19 @@ public final class DataLoaderTest extend Dataset dataset = DataLoader.generateDataset(descriptor, false, fs, path); assertEquals(expected, dataset); + + // regression + source = Utils.randomDoubles(rng, descriptor, false, datasize); + missings = Lists.newArrayList(); + sData = prepareData(source, attrs, missings); + expected = DataLoader.generateDataset(descriptor, false, sData); + + path = Utils.writeDataToTestFile(sData); + fs = path.getFileSystem(new Configuration()); + + dataset = DataLoader.generateDataset(descriptor, false, fs, path); + + assertEquals(expected, dataset); } /** @@ -288,38 +347,4 @@ public final class DataLoaderTest extend lind++; } } - - /** - * each time value appears in data as a label, its corresponding code must - * appear in all the instances with the same label. - * - * @param labelInd label's index in source - * @param value source label's value - */ - static void checkLabel(double[][] source, - Collection<Integer> missings, - Data loaded, - int labelInd, - double value) { - Dataset dataset = loaded.getDataset(); - - // label's code that corresponds to the value - int code = loaded.getDataset().labelCode(Double.toString(value)); - - int lind = 0; - - for (int index = 0; index < source.length; index++) { - if (missings.contains(index)) { - continue; - } - - if (source[index][labelInd] == value) { - assertEquals(code, dataset.getLabel(loaded.get(lind))); - } else { - assertFalse(code == dataset.getLabel(loaded.get(lind))); - } - - lind++; - } - } } Modified: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/data/DataTest.java URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/data/DataTest.java?rev=1213034&r1=1213033&r2=1213034&view=diff ============================================================================== --- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/data/DataTest.java (original) +++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/data/DataTest.java Sun Dec 11 17:53:50 2011 @@ -33,13 +33,16 @@ public class DataTest extends MahoutTest private Random rng; - private Data data; + private Data classifierData; + + private Data regressionData; @Override public void setUp() throws Exception { super.setUp(); rng = RandomUtils.getRandom(); - data = Utils.randomData(rng, ATTRIBUTE_COUNT, false, DATA_SIZE); + classifierData = Utils.randomData(rng, ATTRIBUTE_COUNT, false, DATA_SIZE); + regressionData = Utils.randomData(rng, ATTRIBUTE_COUNT, true, DATA_SIZE); } /** @@ -51,17 +54,45 @@ public class DataTest extends MahoutTest int n = 10; for (int nloop = 0; nloop < n; nloop++) { - int attr = rng.nextInt(data.getDataset().nbAttributes()); + int attr = rng.nextInt(classifierData.getDataset().nbAttributes()); - double[] values = data.values(attr); + double[] values = classifierData.values(attr); double value = values[rng.nextInt(values.length)]; - Data eSubset = data.subset(Condition.equals(attr, value)); - Data lSubset = data.subset(Condition.lesser(attr, value)); - Data gSubset = data.subset(Condition.greaterOrEquals(attr, value)); + Data eSubset = classifierData.subset(Condition.equals(attr, value)); + Data lSubset = classifierData.subset(Condition.lesser(attr, value)); + Data gSubset = classifierData.subset(Condition.greaterOrEquals(attr, value)); for (int index = 0; index < DATA_SIZE; index++) { - Instance instance = data.get(index); + Instance instance = classifierData.get(index); + + if (instance.get(attr) < value) { + assertTrue(lSubset.contains(instance)); + assertFalse(eSubset.contains(instance)); + assertFalse(gSubset.contains(instance)); + } else if (instance.get(attr) == value) { + assertFalse(lSubset.contains(instance)); + assertTrue(eSubset.contains(instance)); + assertTrue(gSubset.contains(instance)); + } else { + assertFalse(lSubset.contains(instance)); + assertFalse(eSubset.contains(instance)); + assertTrue(gSubset.contains(instance)); + } + } + + // regression + attr = rng.nextInt(regressionData.getDataset().nbAttributes()); + + values = regressionData.values(attr); + value = values[rng.nextInt(values.length)]; + + eSubset = regressionData.subset(Condition.equals(attr, value)); + lSubset = regressionData.subset(Condition.lesser(attr, value)); + gSubset = regressionData.subset(Condition.greaterOrEquals(attr, value)); + + for (int index = 0; index < DATA_SIZE; index++) { + Instance instance = regressionData.get(index); if (instance.get(attr) < value) { assertTrue(lSubset.contains(instance)); @@ -82,17 +113,23 @@ public class DataTest extends MahoutTest @Test public void testValues() throws Exception { - Data data = Utils.randomData(rng, ATTRIBUTE_COUNT, false, DATA_SIZE); - - for (int attr = 0; attr < data.getDataset().nbAttributes(); attr++) { - double[] values = data.values(attr); + for (int attr = 0; attr < classifierData.getDataset().nbAttributes(); attr++) { + double[] values = classifierData.values(attr); // each value of the attribute should appear exactly one time in values for (int index = 0; index < DATA_SIZE; index++) { - assertEquals(1, count(values, data.get(index).get(attr))); + assertEquals(1, count(values, classifierData.get(index).get(attr))); } } + for (int attr = 0; attr < regressionData.getDataset().nbAttributes(); attr++) { + double[] values = regressionData.values(attr); + + // each value of the attribute should appear exactly one time in values + for (int index = 0; index < DATA_SIZE; index++) { + assertEquals(1, count(values, regressionData.get(index).get(attr))); + } + } } private static int count(double[] values, double value) { @@ -194,19 +231,33 @@ public class DataTest extends MahoutTest */ @Test public void testBagging() { - Data bag = data.bagging(rng); + Data bag = classifierData.bagging(rng); // the bag should have the same size as the data - assertEquals(data.size(), bag.size()); + assertEquals(classifierData.size(), bag.size()); // at least one element from the data should not be in the bag boolean found = false; - for (int index = 0; index < data.size() && !found; index++) { - found = !bag.contains(data.get(index)); + for (int index = 0; index < classifierData.size() && !found; index++) { + found = !bag.contains(classifierData.get(index)); } assertTrue("some instances from data should not be in the bag", found); - } + + // regression + bag = regressionData.bagging(rng); + + // the bag should have the same size as the data + assertEquals(regressionData.size(), bag.size()); + + // at least one element from the data should not be in the bag + found = false; + for (int index = 0; index < regressionData.size() && !found; index++) { + found = !bag.contains(regressionData.get(index)); + } + + assertTrue("some instances from data should not be in the bag", found); +} /** * Test method for @@ -216,42 +267,61 @@ public class DataTest extends MahoutTest public void testRsplit() { // rsplit should handle empty subsets - Data source = data.clone(); + Data source = classifierData.clone(); Data subset = source.rsplit(rng, 0); assertTrue("subset should be empty", subset.isEmpty()); assertEquals("source.size is incorrect", DATA_SIZE, source.size()); // rsplit should handle full size subsets - source = data.clone(); + source = classifierData.clone(); subset = source.rsplit(rng, DATA_SIZE); assertEquals("subset.size is incorrect", DATA_SIZE, subset.size()); assertTrue("source should be empty", source.isEmpty()); // random case int subsize = rng.nextInt(DATA_SIZE); - source = data.clone(); + source = classifierData.clone(); subset = source.rsplit(rng, subsize); assertEquals("subset.size is incorrect", subsize, subset.size()); assertEquals("source.size is incorrect", DATA_SIZE - subsize, source.size()); - } + + // regression + // rsplit should handle empty subsets + source = regressionData.clone(); + subset = source.rsplit(rng, 0); + assertTrue("subset should be empty", subset.isEmpty()); + assertEquals("source.size is incorrect", DATA_SIZE, source.size()); + + // rsplit should handle full size subsets + source = regressionData.clone(); + subset = source.rsplit(rng, DATA_SIZE); + assertEquals("subset.size is incorrect", DATA_SIZE, subset.size()); + assertTrue("source should be empty", source.isEmpty()); + + // random case + subsize = rng.nextInt(DATA_SIZE); + source = regressionData.clone(); + subset = source.rsplit(rng, subsize); + assertEquals("subset.size is incorrect", subsize, subset.size()); + assertEquals("source.size is incorrect", DATA_SIZE - subsize, source.size()); +} @Test public void testCountLabel() throws Exception { - Data data = Utils.randomData(rng, ATTRIBUTE_COUNT, false, DATA_SIZE); - Dataset dataset = data.getDataset(); + Dataset dataset = classifierData.getDataset(); int[] counts = new int[dataset.nblabels()]; int n = 10; for (int nloop = 0; nloop < n; nloop++) { Arrays.fill(counts, 0); - data.countLabels(counts); + classifierData.countLabels(counts); - for (int index=0;index<data.size();index++) { - counts[dataset.getLabel(data.get(index))]--; + for (int index = 0; index < classifierData.size(); index++) { + counts[(int) dataset.getLabel(classifierData.get(index))]--; } - for (int label = 0; label < data.getDataset().nblabels(); label++) { + for (int label = 0; label < classifierData.getDataset().nblabels(); label++) { assertEquals("Wrong label 'equals' count", 0, counts[0]); } } Modified: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/data/DatasetTest.java URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/data/DatasetTest.java?rev=1213034&r1=1213033&r2=1213034&view=diff ============================================================================== --- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/data/DatasetTest.java (original) +++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/data/DatasetTest.java Sun Dec 11 17:53:50 2011 @@ -56,6 +56,15 @@ public final class DatasetTest extends M dataset.write(out); assertEquals(dataset, readDataset(byteOutStream.toByteArray())); + + // regression + byteOutStream.reset(); + + dataset = Utils.randomData(rng, NUM_ATTRIBUTES, true, 1).getDataset(); + + dataset.write(out); + + assertEquals(dataset, readDataset(byteOutStream.toByteArray())); } } Modified: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/mapreduce/partial/PartialSequentialBuilder.java URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/mapreduce/partial/PartialSequentialBuilder.java?rev=1213034&r1=1213033&r2=1213034&view=diff ============================================================================== --- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/mapreduce/partial/PartialSequentialBuilder.java (original) +++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/mapreduce/partial/PartialSequentialBuilder.java Sun Dec 11 17:53:50 2011 @@ -67,13 +67,13 @@ public class PartialSequentialBuilder ex } @Override - protected void configureJob(Job job, int nbTrees) + protected void configureJob(Job job) throws IOException { Configuration conf = job.getConfiguration(); int num = conf.getInt("mapred.map.tasks", -1); - super.configureJob(job, nbTrees); + super.configureJob(job); // PartialBuilder sets the number of maps to 1 if we are running in 'local' conf.setInt("mapred.map.tasks", num); Modified: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/node/NodeTest.java URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/node/NodeTest.java?rev=1213034&r1=1213033&r2=1213034&view=diff ============================================================================== --- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/node/NodeTest.java (original) +++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/node/NodeTest.java Sun Dec 11 17:53:50 2011 @@ -56,13 +56,13 @@ public final class NodeTest extends Maho public void testReadTree() throws Exception { Node node1 = new CategoricalNode(rng.nextInt(), new double[] { rng.nextDouble(), rng.nextDouble() }, - new Node[] { new Leaf(rng.nextInt()), new Leaf(rng.nextInt()) }); + new Node[] { new Leaf(rng.nextDouble()), new Leaf(rng.nextDouble()) }); Node node2 = new NumericalNode(rng.nextInt(), rng.nextDouble(), - new Leaf(rng.nextInt()), new Leaf(rng.nextInt())); + new Leaf(rng.nextDouble()), new Leaf(rng.nextDouble())); Node root = new CategoricalNode(rng.nextInt(), new double[] { rng.nextDouble(), rng.nextDouble(), rng.nextDouble() }, - new Node[] { node1, node2, new Leaf(rng.nextInt()) }); + new Node[] { node1, node2, new Leaf(rng.nextDouble()) }); // write the node to a DataOutput root.write(out); @@ -80,7 +80,7 @@ public final class NodeTest extends Maho @Test public void testReadLeaf() throws Exception { - Node leaf = new Leaf(rng.nextInt()); + Node leaf = new Leaf(rng.nextDouble()); leaf.write(out); assertEquals(leaf, readNode()); } @@ -89,7 +89,7 @@ public final class NodeTest extends Maho public void testParseNumerical() throws Exception { Node node = new NumericalNode(rng.nextInt(), rng.nextDouble(), new Leaf(rng - .nextInt()), new Leaf(rng.nextInt())); + .nextInt()), new Leaf(rng.nextDouble())); node.write(out); assertEquals(node, readNode()); } @@ -98,8 +98,8 @@ public final class NodeTest extends Maho Node node = new CategoricalNode(rng.nextInt(), new double[]{rng.nextDouble(), rng.nextDouble(), rng.nextDouble()}, new Node[]{ - new Leaf(rng.nextInt()), new Leaf(rng.nextInt()), - new Leaf(rng.nextInt())}); + new Leaf(rng.nextDouble()), new Leaf(rng.nextDouble()), + new Leaf(rng.nextDouble())}); node.write(out); assertEquals(node, readNode()); Modified: mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/df/BreimanExample.java URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/df/BreimanExample.java?rev=1213034&r1=1213033&r2=1213034&view=diff ============================================================================== --- mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/df/BreimanExample.java (original) +++ mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/df/BreimanExample.java Sun Dec 11 17:53:50 2011 @@ -113,8 +113,8 @@ public class BreimanExample extends Conf numNodesOne += forestOne.nbNodes(); // compute the test set error (Selection Error), and mean tree error (One Tree Error), - int[] testLabels = test.extractLabels(); - int[] predictions = new int[test.size()]; + double[] testLabels = test.extractLabels(); + double[] predictions = new double[test.size()]; forestM.classify(test, predictions); sumTestErrM += ErrorEstimate.errorRate(testLabels, predictions); Modified: mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/df/mapreduce/BuildForest.java URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/df/mapreduce/BuildForest.java?rev=1213034&r1=1213033&r2=1213034&view=diff ============================================================================== --- mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/df/mapreduce/BuildForest.java (original) +++ mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/df/mapreduce/BuildForest.java Sun Dec 11 17:53:50 2011 @@ -37,6 +37,7 @@ import org.apache.mahout.common.CommandL import org.apache.mahout.classifier.df.DFUtils; import org.apache.mahout.classifier.df.DecisionForest; import org.apache.mahout.classifier.df.builder.DefaultTreeBuilder; +import org.apache.mahout.classifier.df.builder.TreeBuilder; import org.apache.mahout.classifier.df.data.Data; import org.apache.mahout.classifier.df.data.DataLoader; import org.apache.mahout.classifier.df.data.Dataset; @@ -65,9 +66,12 @@ public class BuildForest extends Configu private Long seed; // Random seed private boolean isPartial; // use partial data implementation + + private String builderName; // Tree builder class name @Override - public int run(String[] args) throws IOException, ClassNotFoundException, InterruptedException { + public int run(String[] args) throws IOException, ClassNotFoundException, InterruptedException, + InstantiationException, IllegalAccessException { DefaultOptionBuilder obuilder = new DefaultOptionBuilder(); ArgumentBuilder abuilder = new ArgumentBuilder(); @@ -99,12 +103,16 @@ public class BuildForest extends Configu abuilder.withName("path").withMinimum(1).withMaximum(1).create()). withDescription("Output path, will contain the Decision Forest").create(); + Option builderOpt = obuilder.withLongName("builder").withShortName("b").withRequired(false) + .withArgument(abuilder.withName("builder").withMinimum(1).withMaximum(1).create()). + withDescription("Tree builder class name").create(); + Option helpOpt = obuilder.withLongName("help").withDescription("Print out help").withShortName("h") .create(); Group group = gbuilder.withName("Options").withOption(dataOpt).withOption(datasetOpt) .withOption(selectionOpt).withOption(seedOpt).withOption(partialOpt).withOption(nbtreesOpt) - .withOption(outputOpt).withOption(helpOpt).create(); + .withOption(outputOpt).withOption(builderOpt).withOption(helpOpt).create(); try { Parser parser = new Parser(); @@ -127,6 +135,10 @@ public class BuildForest extends Configu seed = Long.valueOf(cmdLine.getValue(seedOpt).toString()); } + if (cmdLine.hasOption(builderOpt)) { + builderName = cmdLine.getValue(builderOpt).toString(); + } + if (log.isDebugEnabled()) { log.debug("data : {}", dataName); log.debug("dataset : {}", datasetName); @@ -135,6 +147,7 @@ public class BuildForest extends Configu log.debug("seed : {}", seed); log.debug("nbtrees : {}", nbTrees); log.debug("isPartial : {}", isPartial); + log.debug("builder : {}", builderName); } dataPath = new Path(dataName); @@ -152,7 +165,8 @@ public class BuildForest extends Configu return 0; } - private void buildForest() throws IOException, ClassNotFoundException, InterruptedException { + private void buildForest() throws IOException, ClassNotFoundException, InterruptedException, + InstantiationException, IllegalAccessException { // make sure the output path does not exist FileSystem ofs = outputPath.getFileSystem(getConf()); if (ofs.exists(outputPath)) { @@ -160,8 +174,14 @@ public class BuildForest extends Configu return; } - DefaultTreeBuilder treeBuilder = new DefaultTreeBuilder(); - treeBuilder.setM(m); + TreeBuilder treeBuilder; + if (builderName == null) { + treeBuilder = new DefaultTreeBuilder(); + ((DefaultTreeBuilder) treeBuilder).setM(m); + } else { + Class<?> clazz = Class.forName(builderName); + treeBuilder = (TreeBuilder) clazz.newInstance(); + } Builder forestBuilder; Modified: mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/df/mapreduce/TestForest.java URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/df/mapreduce/TestForest.java?rev=1213034&r1=1213033&r2=1213034&view=diff ============================================================================== --- mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/df/mapreduce/TestForest.java (original) +++ mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/df/mapreduce/TestForest.java Sun Dec 11 17:53:50 2011 @@ -18,6 +18,8 @@ package org.apache.mahout.classifier.df.mapreduce; import java.io.IOException; +import java.util.ArrayList; +import java.util.List; import java.util.Random; import java.util.Scanner; import java.util.Arrays; @@ -44,6 +46,7 @@ import org.apache.mahout.common.RandomUt import org.apache.mahout.common.commandline.DefaultOptionCreator; import org.apache.mahout.classifier.df.DFUtils; import org.apache.mahout.classifier.df.DecisionForest; +import org.apache.mahout.classifier.RegressionResultAnalyzer; import org.apache.mahout.classifier.ResultAnalyzer; import org.apache.mahout.classifier.ClassifierResult; import org.apache.mahout.classifier.df.data.DataConverter; @@ -179,12 +182,27 @@ public class TestForest extends Configur throw new IllegalArgumentException("You must specify the ouputPath when using the mapreduce implementation"); } - Classifier classifier = new Classifier(modelPath, dataPath, datasetPath, outputPath, getConf(), analyze); + Classifier classifier = new Classifier(modelPath, dataPath, datasetPath, outputPath, getConf()); classifier.run(); if (analyze) { - log.info("{}", classifier.getAnalyzer()); + double[][] results = classifier.getResults(); + if (results != null) { + Dataset dataset = Dataset.load(getConf(), datasetPath); + if (dataset.isNumerical(dataset.getLabelId())) { + RegressionResultAnalyzer regressionAnalyzer = new RegressionResultAnalyzer(); + regressionAnalyzer.setInstances(results); + log.info("{}", regressionAnalyzer); + } else { + ResultAnalyzer analyzer = new ResultAnalyzer(Arrays.asList(dataset.labels()), "unknown"); + for (double[] res : results) { + analyzer.addInstance(dataset.getLabelString(res[0]), + new ClassifierResult(dataset.getLabelString(res[1]), 1.0)); + } + log.info("{}", analyzer); + } + } } } @@ -206,37 +224,49 @@ public class TestForest extends Configur long time = System.currentTimeMillis(); Random rng = RandomUtils.getRandom(); - ResultAnalyzer analyzer = analyze ? new ResultAnalyzer(Arrays.asList(dataset.labels()), "unknown") : null; + List<double[]> resList = new ArrayList<double[]>(); if (dataFS.getFileStatus(dataPath).isDir()) { //the input is a directory of files - testDirectory(outputPath, converter, forest, dataset, analyzer, rng); + testDirectory(outputPath, converter, forest, dataset, resList, rng); } else { // the input is one single file - testFile(dataPath, outputPath, converter, forest, dataset, analyzer, rng); + testFile(dataPath, outputPath, converter, forest, dataset, resList, rng); } time = System.currentTimeMillis() - time; log.info("Classification Time: {}", DFUtils.elapsedTime(time)); - if (analyzer != null) { - log.info("{}", analyzer); + if (analyze) { + if (dataset.isNumerical(dataset.getLabelId())) { + RegressionResultAnalyzer regressionAnalyzer = new RegressionResultAnalyzer(); + double[][] results = new double[resList.size()][2]; + regressionAnalyzer.setInstances(resList.toArray(results)); + log.info("{}", regressionAnalyzer); + } else { + ResultAnalyzer analyzer = new ResultAnalyzer(Arrays.asList(dataset.labels()), "unknown"); + for (double[] r : resList) { + analyzer.addInstance(dataset.getLabelString(r[0]), + new ClassifierResult(dataset.getLabelString(r[1]), 1.0)); + } + log.info("{}", analyzer); + } } } - private void testDirectory(Path outPath, DataConverter converter, DecisionForest forest, Dataset dataset, - ResultAnalyzer analyzer, Random rng) throws IOException { + private void testDirectory(Path outPath, DataConverter converter, DecisionForest forest, + Dataset dataset, List<double[]> results, Random rng) throws IOException { Path[] infiles = DFUtils.listOutputFiles(dataFS, dataPath); for (Path path : infiles) { log.info("Classifying : {}", path); Path outfile = outPath != null ? new Path(outPath, path.getName()).suffix(".out") : null; - testFile(path, outfile, converter, forest, dataset, analyzer, rng); + testFile(path, outfile, converter, forest, dataset, results, rng); } } - private void testFile(Path inPath, Path outPath, DataConverter converter, DecisionForest forest, Dataset dataset, - ResultAnalyzer analyzer, Random rng) throws IOException { + private void testFile(Path inPath, Path outPath, DataConverter converter, DecisionForest forest, + Dataset dataset, List<double[]> results, Random rng) throws IOException { // create the predictions file FSDataOutputStream ofile = null; @@ -255,17 +285,14 @@ public class TestForest extends Configur } Instance instance = converter.convert(line); - int prediction = forest.classify(rng, instance); + double prediction = forest.classify(dataset, rng, instance); if (outputPath != null) { - ofile.writeChars(Integer.toString(prediction)); // write the prediction + ofile.writeChars(Double.toString(prediction)); // write the prediction ofile.writeChar('\n'); } - - if (analyzer != null) { - analyzer.addInstance(dataset.getLabelString(dataset.getLabel(instance)), - new ClassifierResult(dataset.getLabelString(prediction), 1.0)); - } + + results.add(new double[] {dataset.getLabel(instance), prediction}); } scanner.close();
