Author: adeneche
Date: Tue Oct 11 04:52:36 2011
New Revision: 1181625
URL: http://svn.apache.org/viewvc?rev=1181625&view=rev
Log:
MAHOUT-835: removed all code related to out-of-bag (OOB) error computation
Removed:
mahout/trunk/core/src/main/java/org/apache/mahout/df/callback/
mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/partial/InterResults.java
mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/partial/Step0Job.java
mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/partial/Step2Job.java
mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/partial/Step2Mapper.java
mahout/trunk/core/src/test/java/org/apache/mahout/df/mapreduce/partial/InterResultsTest.java
mahout/trunk/core/src/test/java/org/apache/mahout/df/mapreduce/partial/PartitionBugTest.java
mahout/trunk/core/src/test/java/org/apache/mahout/df/mapreduce/partial/Step0JobTest.java
mahout/trunk/core/src/test/java/org/apache/mahout/df/mapreduce/partial/Step2MapperTest.java
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/df/Bagging.java
mahout/trunk/core/src/main/java/org/apache/mahout/df/DecisionForest.java
mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/Builder.java
mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/MapredMapper.java
mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/inmem/InMemBuilder.java
mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/inmem/InMemMapper.java
mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/partial/PartialBuilder.java
mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/partial/Step1Mapper.java
mahout/trunk/core/src/main/java/org/apache/mahout/df/ref/SequentialBuilder.java
mahout/trunk/core/src/test/java/org/apache/mahout/df/mapreduce/partial/PartialBuilderTest.java
mahout/trunk/core/src/test/java/org/apache/mahout/df/mapreduce/partial/PartialSequentialBuilder.java
mahout/trunk/core/src/test/java/org/apache/mahout/df/mapreduce/partial/Step1MapperTest.java
mahout/trunk/examples/src/main/java/org/apache/mahout/df/BreimanExample.java
mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapreduce/BuildForest.java
Modified: mahout/trunk/core/src/main/java/org/apache/mahout/df/Bagging.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/df/Bagging.java?rev=1181625&r1=1181624&r2=1181625&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/df/Bagging.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/df/Bagging.java Tue Oct 11 04:52:36 2011
@@ -21,7 +21,6 @@ import java.util.Arrays;
import java.util.Random;
import org.apache.mahout.df.builder.TreeBuilder;
-import org.apache.mahout.df.callback.PredictionCallback;
import org.apache.mahout.df.data.Data;
import org.apache.mahout.df.data.Instance;
import org.apache.mahout.df.node.Node;
@@ -53,27 +52,13 @@ public class Bagging {
* @param treeId
* tree identifier
*/
- public Node build(int treeId, Random rng, PredictionCallback callback) {
+ public Node build(int treeId, Random rng) {
log.debug("Bagging...");
Arrays.fill(sampled, false);
Data bag = data.bagging(rng, sampled);
log.debug("Building...");
- Node tree = treeBuilder.build(rng, bag);
-
- // predict the label for the out-of-bag elements
- if (callback != null) {
- log.debug("Oob error estimation");
- for (int index = 0; index < data.size(); index++) {
- if (!sampled[index]) {
- Instance instance = data.get(index);
- int prediction = tree.classify(instance);
- callback.prediction(treeId, instance.getId(), prediction);
- }
- }
- }
-
- return tree;
+ return treeBuilder.build(rng, bag);
}
}
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/df/DecisionForest.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/df/DecisionForest.java?rev=1181625&r1=1181624&r2=1181625&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/df/DecisionForest.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/df/DecisionForest.java Tue Oct 11 04:52:36 2011
@@ -25,7 +25,6 @@ import java.io.DataInput;
import com.google.common.collect.Lists;
import com.google.common.io.Closeables;
-import org.apache.mahout.df.callback.PredictionCallback;
import org.apache.mahout.df.data.Data;
import org.apache.mahout.df.data.DataUtils;
import org.apache.mahout.df.data.Instance;
@@ -62,8 +61,8 @@ public class DecisionForest implements W
/**
* Classifies the data and calls callback for each classification
*/
- public void classify(Data data, PredictionCallback callback) {
- Preconditions.checkArgument(callback != null, "callback must not be null");
+ public void classify(Data data, int[] predictions) {
+ Preconditions.checkArgument(data.size() == predictions.length,
"predictions.length must be equal to data.size()");
if (data.isEmpty()) {
return; // nothing to classify
@@ -73,8 +72,7 @@ public class DecisionForest implements W
Node tree = trees.get(treeId);
for (int index = 0; index < data.size(); index++) {
- int prediction = tree.classify(data.get(index));
- callback.prediction(treeId, index, prediction);
+ predictions[index] = tree.classify(data.get(index));
}
}
}
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/Builder.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/Builder.java?rev=1181625&r1=1181624&r2=1181625&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/Builder.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/Builder.java Tue Oct 11 04:52:36 2011
@@ -32,7 +32,6 @@ import org.apache.mahout.common.HadoopUt
import org.apache.mahout.common.StringUtils;
import org.apache.mahout.df.DecisionForest;
import org.apache.mahout.df.builder.TreeBuilder;
-import org.apache.mahout.df.callback.PredictionCallback;
import org.apache.mahout.df.data.Dataset;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -108,14 +107,6 @@ public abstract class Builder {
return conf.getBoolean("debug.mahout.rf.output", true);
}
- protected static boolean isOobEstimate(Configuration conf) {
- return conf.getBoolean("mahout.rf.oob", false);
- }
-
- private static void setOobEstimate(Configuration conf, boolean value) {
- conf.setBoolean("mahout.rf.oob", value);
- }
-
/**
* Returns the random seed
*
@@ -254,12 +245,10 @@ public abstract class Builder {
* Hadoop's Job
* @param nbTrees
* number of trees to grow
- * @param oobEstimate
- * true, if oob error should be estimated
* @throws IOException
* if anything goes wrong while configuring the job
*/
- protected abstract void configureJob(Job job, int nbTrees, boolean
oobEstimate) throws IOException;
+ protected abstract void configureJob(Job job, int nbTrees) throws
IOException;
/**
* Sequential implementation should override this method to simulate the job
execution
@@ -277,16 +266,14 @@ public abstract class Builder {
*
* @param job
* Hadoop's job
- * @param callback
- * can be null
* @return Built DecisionForest
* @throws IOException
* if anything goes wrong while parsing the output
*/
- protected abstract DecisionForest parseOutput(Job job, PredictionCallback
callback)
+ protected abstract DecisionForest parseOutput(Job job)
throws IOException, ClassNotFoundException, InterruptedException;
- public DecisionForest build(int nbTrees, PredictionCallback callback)
+ public DecisionForest build(int nbTrees)
throws IOException, ClassNotFoundException, InterruptedException {
// int numTrees = getNbTrees(conf);
@@ -303,7 +290,6 @@ public abstract class Builder {
}
setNbTrees(conf, nbTrees);
setTreeBuilder(conf, treeBuilder);
- setOobEstimate(conf, callback != null);
// put the dataset into the DistributedCache
DistributedCache.addCacheFile(datasetPath.toUri(), conf);
@@ -311,7 +297,7 @@ public abstract class Builder {
Job job = new Job(conf, "decision forest builder");
log.debug("Configuring the job...");
- configureJob(job, nbTrees, callback != null);
+ configureJob(job, nbTrees);
log.debug("Running the job...");
if (!runJob(job)) {
@@ -321,7 +307,7 @@ public abstract class Builder {
if (isOutput(conf)) {
log.debug("Parsing the output...");
- DecisionForest forest = parseOutput(job, callback);
+ DecisionForest forest = parseOutput(job);
HadoopUtil.delete(conf, outputPath);
return forest;
}
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/MapredMapper.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/MapredMapper.java?rev=1181625&r1=1181624&r2=1181625&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/MapredMapper.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/MapredMapper.java Tue Oct 11 04:52:36 2011
@@ -33,22 +33,12 @@ public class MapredMapper<KEYIN,VALUEIN,
private boolean noOutput;
- private boolean oobEstimate;
-
private TreeBuilder treeBuilder;
private Dataset dataset;
/**
*
- * @return if false, the mapper does not output
- */
- protected boolean isOobEstimate() {
- return oobEstimate;
- }
-
- /**
- *
* @return if false, the mapper does not estimate and output predictions
*/
protected boolean isNoOutput() {
@@ -69,17 +59,16 @@ public class MapredMapper<KEYIN,VALUEIN,
Configuration conf = context.getConfiguration();
- configure(!Builder.isOutput(conf), Builder.isOobEstimate(conf),
Builder.getTreeBuilder(conf), Builder
+ configure(!Builder.isOutput(conf), Builder.getTreeBuilder(conf), Builder
.loadDataset(conf));
}
/**
* Useful for testing
*/
- protected void configure(boolean noOutput, boolean oobEstimate, TreeBuilder
treeBuilder, Dataset dataset) {
+ protected void configure(boolean noOutput, TreeBuilder treeBuilder, Dataset
dataset) {
Preconditions.checkArgument(treeBuilder != null, "TreeBuilder not found in
the Job parameters");
this.noOutput = noOutput;
- this.oobEstimate = oobEstimate;
this.treeBuilder = treeBuilder;
this.dataset = dataset;
}
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/inmem/InMemBuilder.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/inmem/InMemBuilder.java?rev=1181625&r1=1181624&r2=1181625&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/inmem/InMemBuilder.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/inmem/InMemBuilder.java Tue Oct 11 04:52:36 2011
@@ -21,8 +21,6 @@ import java.io.IOException;
import java.util.List;
import java.util.Map;
-import com.google.common.collect.Lists;
-import com.google.common.collect.Maps;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.filecache.DistributedCache;
import org.apache.hadoop.fs.FileSystem;
@@ -36,11 +34,13 @@ import org.apache.mahout.common.iterator
import org.apache.mahout.df.DFUtils;
import org.apache.mahout.df.DecisionForest;
import org.apache.mahout.df.builder.TreeBuilder;
-import org.apache.mahout.df.callback.PredictionCallback;
import org.apache.mahout.df.mapreduce.Builder;
import org.apache.mahout.df.mapreduce.MapredOutput;
import org.apache.mahout.df.node.Node;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Maps;
+
/**
* MapReduce implementation where each mapper loads a full copy of the data
in-memory. The forest trees are
* splitted across all the mappers
@@ -56,7 +56,7 @@ public class InMemBuilder extends Builde
}
@Override
- protected void configureJob(Job job, int nbTrees, boolean oobEstimate)
throws IOException {
+ protected void configureJob(Job job, int nbTrees) throws IOException {
Configuration conf = job.getConfiguration();
job.setJarByClass(InMemBuilder.class);
@@ -78,7 +78,7 @@ public class InMemBuilder extends Builde
}
@Override
- protected DecisionForest parseOutput(Job job, PredictionCallback callback)
throws IOException {
+ protected DecisionForest parseOutput(Job job) throws IOException {
Configuration conf = job.getConfiguration();
Map<Integer,MapredOutput> output = Maps.newHashMap();
@@ -95,26 +95,18 @@ public class InMemBuilder extends Builde
}
}
- return processOutput(output, callback);
+ return processOutput(output);
}
/**
- * Process the output, extracting the trees and passing the predictions to
the callback
+ * Process the output, extracting the trees
*/
- private static DecisionForest processOutput(Map<Integer,MapredOutput>
output, PredictionCallback callback) {
+ private static DecisionForest processOutput(Map<Integer,MapredOutput>
output) {
List<Node> trees = Lists.newArrayList();
for (Map.Entry<Integer,MapredOutput> entry : output.entrySet()) {
MapredOutput value = entry.getValue();
-
trees.add(value.getTree());
-
- if (callback != null) {
- int[] predictions = value.getPredictions();
- for (int index = 0; index < predictions.length; index++) {
- callback.prediction(entry.getKey(), index, predictions[index]);
- }
- }
}
return new DecisionForest(trees);
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/inmem/InMemMapper.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/inmem/InMemMapper.java?rev=1181625&r1=1181624&r2=1181625&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/inmem/InMemMapper.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/inmem/InMemMapper.java Tue Oct 11 04:52:36 2011
@@ -27,7 +27,6 @@ import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.NullWritable;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.df.Bagging;
-import org.apache.mahout.df.callback.SingleTreePredictions;
import org.apache.mahout.df.data.Data;
import org.apache.mahout.df.data.DataLoader;
import org.apache.mahout.df.data.Dataset;
@@ -84,22 +83,14 @@ public class InMemMapper extends MapredM
protected void map(IntWritable key, Context context) throws IOException,
InterruptedException {
- SingleTreePredictions callback = null;
- int[] predictions = null;
-
- if (isOobEstimate() && !isNoOutput()) {
- callback = new SingleTreePredictions(data.size());
- predictions = callback.getPredictions();
- }
-
initRandom((InMemInputSplit) context.getInputSplit());
log.debug("Building...");
- Node tree = bagging.build(key.get(), rng, callback);
+ Node tree = bagging.build(key.get(), rng);
if (!isNoOutput()) {
log.debug("Outputing...");
- MapredOutput mrOut = new MapredOutput(tree, predictions);
+ MapredOutput mrOut = new MapredOutput(tree);
context.write(key, mrOut);
}
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/partial/PartialBuilder.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/partial/PartialBuilder.java?rev=1181625&r1=1181624&r2=1181625&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/partial/PartialBuilder.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/partial/PartialBuilder.java Tue Oct 11 04:52:36 2011
@@ -34,10 +34,8 @@ import org.apache.mahout.common.iterator
import org.apache.mahout.df.DFUtils;
import org.apache.mahout.df.DecisionForest;
import org.apache.mahout.df.builder.TreeBuilder;
-import org.apache.mahout.df.callback.PredictionCallback;
import org.apache.mahout.df.mapreduce.Builder;
import org.apache.mahout.df.mapreduce.MapredOutput;
-import org.apache.mahout.df.mapreduce.partial.Step0Job.Step0Output;
import org.apache.mahout.df.node.Node;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@@ -64,14 +62,6 @@ public class PartialBuilder extends Buil
}
/**
- * Indicates if we should run the second step of the builder.<br>
- * This parameter is only meant for debuging, so we keep it protected.
- */
- protected static boolean isStep2(Configuration conf) {
- return conf.getBoolean("debug.mahout.rf.partial.step2", true);
- }
-
- /**
* Should run the second step of the builder ?
*
* @param value
@@ -83,7 +73,7 @@ public class PartialBuilder extends Buil
}
@Override
- protected void configureJob(Job job, int nbTrees, boolean oobEstimate)
throws IOException {
+ protected void configureJob(Job job, int nbTrees) throws IOException {
Configuration conf = job.getConfiguration();
job.setJarByClass(PartialBuilder.class);
@@ -102,7 +92,7 @@ public class PartialBuilder extends Buil
}
@Override
- protected DecisionForest parseOutput(Job job, PredictionCallback callback)
+ protected DecisionForest parseOutput(Job job)
throws IOException, ClassNotFoundException, InterruptedException {
Configuration conf = job.getConfiguration();
@@ -113,31 +103,8 @@ public class PartialBuilder extends Buil
int[] firstIds = null;
TreeID[] keys = new TreeID[numTrees];
Node[] trees = new Node[numTrees];
- Step0Output[] partitions = null;
- int numMaps = 0;
-
- if (callback != null) {
- log.info("Computing partitions' first ids...");
- Step0Job step0 = new Step0Job(getOutputPath(conf), getDataPath(),
getDatasetPath());
- partitions = step0.run(new Configuration(conf));
-
- log.info("Processing the output...");
- firstIds = Step0Output.extractFirstIds(partitions);
-
- numMaps = partitions.length;
- }
-
- processOutput(job, outputPath, firstIds, keys, trees, callback);
-
- // call the second step in order to complete the oob predictions
- if (callback != null && numMaps > 1 && isStep2(conf)) {
- log.info("*****************************");
- log.info("Second Step");
- log.info("*****************************");
- Step2Job step2 = new Step2Job(getOutputPath(conf), getDataPath(),
getDatasetPath(), partitions);
-
- step2.run(new Configuration(conf), keys, trees, callback);
- }
+
+ processOutput(job, outputPath, firstIds, keys, trees);
return new DecisionForest(Arrays.asList(trees));
}
@@ -154,15 +121,12 @@ public class PartialBuilder extends Buil
* can be null
* @param trees
* can be null
- * @param callback
- * can be null
*/
protected static void processOutput(JobContext job,
Path outputPath,
int[] firstIds,
TreeID[] keys,
- Node[] trees,
- PredictionCallback callback) throws
IOException {
+ Node[] trees) throws IOException {
Preconditions.checkArgument(keys == null && trees == null || keys != null
&& trees != null,
"if keys is null, trees should also be null");
Preconditions.checkArgument(keys == null || keys.length == trees.length,
"keys.length != trees.length");
@@ -185,9 +149,6 @@ public class PartialBuilder extends Buil
if (trees != null) {
trees[index] = value.getTree();
}
- if (callback != null) {
- processOutput(firstIds, key, value, callback);
- }
index++;
}
}
@@ -197,22 +158,4 @@ public class PartialBuilder extends Buil
throw new IllegalStateException("Some key/values are missing from the
output");
}
}
-
- /**
- * Process the output, extracting the trees and passing the predictions to
the callback
- *
- * @param firstIds
- * partitions' first ids in hadoop's order
- */
- private static void processOutput(int[] firstIds,
- TreeID key,
- MapredOutput value,
- PredictionCallback callback) {
-
- int[] predictions = value.getPredictions();
-
- for (int instanceId = 0; instanceId < predictions.length; instanceId++) {
- callback.prediction(key.treeId(), firstIds[key.partition()] +
instanceId, predictions[instanceId]);
- }
- }
}
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/partial/Step1Mapper.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/partial/Step1Mapper.java?rev=1181625&r1=1181624&r2=1181625&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/partial/Step1Mapper.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/df/mapreduce/partial/Step1Mapper.java Tue Oct 11 04:52:36 2011
@@ -21,13 +21,11 @@ import java.io.IOException;
import java.util.List;
import java.util.Random;
-import com.google.common.collect.Lists;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.df.Bagging;
-import org.apache.mahout.df.callback.SingleTreePredictions;
import org.apache.mahout.df.data.Data;
import org.apache.mahout.df.data.DataConverter;
import org.apache.mahout.df.data.Instance;
@@ -39,6 +37,7 @@ import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import com.google.common.base.Preconditions;
+import com.google.common.collect.Lists;
/**
* First step of the Partial Data Builder. Builds the trees using the data
available in the InputSplit.
@@ -157,21 +156,15 @@ public class Step1Mapper extends MapredM
TreeID key = new TreeID();
log.debug("Building {} trees", nbTrees);
- SingleTreePredictions callback = null;
- int[] predictions = null;
for (int treeId = 0; treeId < nbTrees; treeId++) {
log.debug("Building tree number : {}", treeId);
- if (isOobEstimate() && !isNoOutput()) {
- callback = new SingleTreePredictions(data.size());
- predictions = callback.getPredictions();
- }
- Node tree = bagging.build(treeId, rng, callback);
+ Node tree = bagging.build(treeId, rng);
key.set(partition, firstTreeId + treeId);
if (!isNoOutput()) {
- MapredOutput emOut = new MapredOutput(tree, predictions);
+ MapredOutput emOut = new MapredOutput(tree);
context.write(key, emOut);
}
}
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/df/ref/SequentialBuilder.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/df/ref/SequentialBuilder.java?rev=1181625&r1=1181624&r2=1181625&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/df/ref/SequentialBuilder.java (original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/df/ref/SequentialBuilder.java Tue Oct 11 04:52:36 2011
@@ -24,7 +24,6 @@ import com.google.common.collect.Lists;
import org.apache.mahout.df.Bagging;
import org.apache.mahout.df.DecisionForest;
import org.apache.mahout.df.builder.TreeBuilder;
-import org.apache.mahout.df.callback.PredictionCallback;
import org.apache.mahout.df.data.Data;
import org.apache.mahout.df.node.Node;
import org.slf4j.Logger;
@@ -56,11 +55,11 @@ public class SequentialBuilder {
bagging = new Bagging(treeBuilder, data);
}
- public DecisionForest build(int nbTrees, PredictionCallback callback) {
+ public DecisionForest build(int nbTrees) {
List<Node> trees = Lists.newArrayList();
for (int treeId = 0; treeId < nbTrees; treeId++) {
- trees.add(bagging.build(treeId, rng, callback));
+ trees.add(bagging.build(treeId, rng));
logProgress(((float) treeId + 1) / nbTrees);
}
Modified:
mahout/trunk/core/src/test/java/org/apache/mahout/df/mapreduce/partial/PartialBuilderTest.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/df/mapreduce/partial/PartialBuilderTest.java?rev=1181625&r1=1181624&r2=1181625&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/df/mapreduce/partial/PartialBuilderTest.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/df/mapreduce/partial/PartialBuilderTest.java Tue Oct 11 04:52:36 2011
@@ -35,7 +35,6 @@ import org.apache.mahout.common.MahoutTe
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.df.builder.DefaultTreeBuilder;
import org.apache.mahout.df.builder.TreeBuilder;
-import org.apache.mahout.df.callback.PredictionCallback;
import org.apache.mahout.df.mapreduce.MapredOutput;
import org.apache.mahout.df.node.Leaf;
import org.apache.mahout.df.node.Node;
@@ -83,8 +82,7 @@ public final class PartialBuilderTest ex
TreeID[] newKeys = new TreeID[NUM_TREES];
Node[] newTrees = new Node[NUM_TREES];
- PartialBuilder.processOutput(new Job(conf), base, firstIds, newKeys,
newTrees,
- new TestCallback(keys, values));
+ PartialBuilder.processOutput(new Job(conf), base, firstIds, newKeys,
newTrees);
// check the forest
for (int tree = 0; tree < NUM_TREES; tree++) {
@@ -188,7 +186,6 @@ public final class PartialBuilderTest ex
assertEquals(NUM_TREES, getNbTrees(conf));
assertFalse(isOutput(conf));
- assertTrue(isOobEstimate(conf));
assertEquals(treeBuilder, getTreeBuilder(conf));
@@ -198,32 +195,4 @@ public final class PartialBuilderTest ex
}
}
-
- /**
- * Mock Callback. Make sure that the callback receives the correct
predictions
- *
- */
- static class TestCallback implements PredictionCallback {
-
- private final TreeID[] keys;
-
- private final MapredOutput[] values;
-
- TestCallback(TreeID[] keys, MapredOutput[] values) {
- this.keys = keys;
- this.values = values;
- }
-
- @Override
- public void prediction(int treeId, int instanceId, int prediction) {
- int partition = instanceId / NUM_INSTANCES;
-
- TreeID key = new TreeID(partition, treeId);
- int index = ArrayUtils.indexOf(keys, key);
- assertTrue("key not found", index >= 0);
-
- assertEquals(values[index].getPredictions()[instanceId % NUM_INSTANCES],
prediction);
- }
-
- }
}
Modified:
mahout/trunk/core/src/test/java/org/apache/mahout/df/mapreduce/partial/PartialSequentialBuilder.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/df/mapreduce/partial/PartialSequentialBuilder.java?rev=1181625&r1=1181624&r2=1181625&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/df/mapreduce/partial/PartialSequentialBuilder.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/df/mapreduce/partial/PartialSequentialBuilder.java Tue Oct 11 04:52:36 2011
@@ -20,17 +20,13 @@ package org.apache.mahout.df.mapreduce.p
import java.io.IOException;
import java.util.List;
-import com.google.common.collect.Lists;
import org.apache.commons.lang.ArrayUtils;
import org.apache.hadoop.conf.Configuration;
-import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.InputSplit;
import org.apache.hadoop.mapreduce.Job;
-import org.apache.hadoop.mapreduce.JobContext;
-import org.apache.hadoop.mapreduce.JobID;
import org.apache.hadoop.mapreduce.RecordReader;
import org.apache.hadoop.mapreduce.TaskAttemptContext;
import org.apache.hadoop.mapreduce.TaskAttemptID;
@@ -38,7 +34,6 @@ import org.apache.hadoop.mapreduce.lib.i
import org.apache.mahout.df.DFUtils;
import org.apache.mahout.df.DecisionForest;
import org.apache.mahout.df.builder.TreeBuilder;
-import org.apache.mahout.df.callback.PredictionCallback;
import org.apache.mahout.df.data.Dataset;
import org.apache.mahout.df.mapreduce.Builder;
import org.apache.mahout.df.mapreduce.MapredOutput;
@@ -46,6 +41,8 @@ import org.apache.mahout.df.node.Node;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
+import com.google.common.collect.Lists;
+
/**
* Simulates the Partial mapreduce implementation in a sequential manner. Must
* receive a seed
@@ -56,8 +53,6 @@ public class PartialSequentialBuilder ex
private MockContext firstOutput;
- private MockContext secondOutput;
-
private final Dataset dataset;
/** first instance id in hadoop's order */
@@ -78,13 +73,13 @@ public class PartialSequentialBuilder ex
}
@Override
- protected void configureJob(Job job, int nbTrees, boolean oobEstimate)
+ protected void configureJob(Job job, int nbTrees)
throws IOException {
Configuration conf = job.getConfiguration();
int num = conf.getInt("mapred.map.tasks", -1);
- super.configureJob(job, nbTrees, oobEstimate);
+ super.configureJob(job, nbTrees);
// PartialBuilder sets the number of maps to 1 if we are running in 'local'
conf.setInt("mapred.map.tasks", num);
@@ -151,128 +146,32 @@ public class PartialSequentialBuilder ex
}
@Override
- protected DecisionForest parseOutput(Job job, PredictionCallback callback)
throws IOException, InterruptedException {
- Configuration conf = job.getConfiguration();
-
- DecisionForest forest = processOutput(firstOutput.getKeys(),
firstOutput.getValues(), callback);
-
- if (isStep2(conf)) {
- Path forestPath = new Path(getOutputPath(conf), "step1.inter");
- FileSystem fs = forestPath.getFileSystem(conf);
-
- Node[] trees = new Node[forest.getTrees().size()];
- forest.getTrees().toArray(trees);
- InterResults.store(fs, forestPath, firstOutput.getKeys(), trees, sizes);
-
- log.info("***********");
- log.info("Second Step");
- log.info("***********");
- secondStep(conf, forestPath, callback);
-
- processOutput(secondOutput.getKeys(), secondOutput.getValues(),
callback);
- }
-
- return forest;
+ protected DecisionForest parseOutput(Job job) throws IOException,
InterruptedException {
+ return processOutput(firstOutput.getKeys(), firstOutput.getValues());
}
/**
- * extract the decision forest and call the callback after correcting the
instance ids
+ * extract the decision forest
*/
- protected DecisionForest processOutput(TreeID[] keys, MapredOutput[] values,
PredictionCallback callback) {
+ protected DecisionForest processOutput(TreeID[] keys, MapredOutput[] values)
{
List<Node> trees = Lists.newArrayList();
for (int index = 0; index < keys.length; index++) {
- TreeID key = keys[index];
MapredOutput value = values[index];
-
trees.add(value.getTree());
-
- int[] predictions = value.getPredictions();
- for (int id = 0; id < predictions.length; id++) {
- callback.prediction(key.treeId(), firstIds[key.partition()] + id,
- predictions[id]);
- }
}
return new DecisionForest(trees);
}
/**
- * The second step uses the trees to predict the rest of the instances
outside
- * their own partition
- */
- protected void secondStep(Configuration conf, Path forestPath,
PredictionCallback callback)
- throws IOException, InterruptedException {
- JobContext jobContext = new JobContext(conf, new JobID());
-
- // retrieve the splits
- TextInputFormat input = new TextInputFormat();
- List<InputSplit> splits = input.getSplits(jobContext);
-
- int nbSplits = splits.size();
- log.debug("Nb splits : {}", nbSplits);
-
- InputSplit[] sorted = new InputSplit[nbSplits];
- splits.toArray(sorted);
- Builder.sortSplits(sorted);
-
- int numTrees = Builder.getNbTrees(conf); // total number of trees
-
- // compute the expected number of outputs
- int total = 0;
- for (int p = 0; p < nbSplits; p++) {
- total += Step2Mapper.nbConcerned(nbSplits, numTrees, p);
- }
-
- TaskAttemptContext task = new TaskAttemptContext(conf, new
TaskAttemptID());
-
- secondOutput = new MockContext(new Step2Mapper(), conf,
task.getTaskAttemptID(), numTrees);
- long slowest = 0; // duration of slowest map
-
- for (int partition = 0; partition < nbSplits; partition++) {
-
- InputSplit split = sorted[partition];
- RecordReader<LongWritable, Text> reader =
input.createRecordReader(split, task);
-
- // load the output of the 1st step
- int nbConcerned = Step2Mapper.nbConcerned(nbSplits, numTrees, partition);
- TreeID[] fsKeys = new TreeID[nbConcerned];
- Node[] fsTrees = new Node[nbConcerned];
-
- FileSystem fs = forestPath.getFileSystem(conf);
- int numInstances = InterResults.load(fs, forestPath, nbSplits,
- numTrees, partition, fsKeys, fsTrees);
-
- Step2Mapper mapper = new Step2Mapper();
- mapper.configure(partition, dataset, fsKeys, fsTrees, numInstances);
-
- long time = System.currentTimeMillis();
-
- while (reader.nextKeyValue()) {
- mapper.map(reader.getCurrentKey(), reader.getCurrentValue(),
secondOutput);
- }
-
- mapper.cleanup(secondOutput);
-
- time = System.currentTimeMillis() - time;
- log.info("Duration : {}", DFUtils.elapsedTime(time));
-
- if (time > slowest) {
- slowest = time;
- }
- }
-
- log.info("Longest duration : {}", DFUtils.elapsedTime(slowest));
- }
-
- /**
* Special Step1Mapper that can be configured without using a Configuration
*
*/
private static class MockStep1Mapper extends Step1Mapper {
protected MockStep1Mapper(TreeBuilder treeBuilder, Dataset dataset, Long seed,
int partition, int numMapTasks, int numTrees) {
- configure(false, true, treeBuilder, dataset);
+ configure(false, treeBuilder, dataset);
configure(seed, partition, numMapTasks, numTrees);
}
Modified:
mahout/trunk/core/src/test/java/org/apache/mahout/df/mapreduce/partial/Step1MapperTest.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/df/mapreduce/partial/Step1MapperTest.java?rev=1181625&r1=1181624&r2=1181625&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/df/mapreduce/partial/Step1MapperTest.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/df/mapreduce/partial/Step1MapperTest.java Tue Oct 11 04:52:36 2011
@@ -66,7 +66,7 @@ public final class Step1MapperTest exten
private static class MockStep1Mapper extends Step1Mapper {
private MockStep1Mapper(TreeBuilder treeBuilder, Dataset dataset, Long seed,
int partition, int numMapTasks, int numTrees) {
- configure(false, true, treeBuilder, dataset);
+ configure(false, treeBuilder, dataset);
configure(seed, partition, numMapTasks, numTrees);
}
}
Modified:
mahout/trunk/examples/src/main/java/org/apache/mahout/df/BreimanExample.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/df/BreimanExample.java?rev=1181625&r1=1181624&r2=1181625&view=diff
==============================================================================
--- mahout/trunk/examples/src/main/java/org/apache/mahout/df/BreimanExample.java (original)
+++ mahout/trunk/examples/src/main/java/org/apache/mahout/df/BreimanExample.java Tue Oct 11 04:52:36 2011
@@ -37,9 +37,6 @@ import org.apache.hadoop.util.ToolRunner
import org.apache.mahout.common.CommandLineUtil;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.df.builder.DefaultTreeBuilder;
-import org.apache.mahout.df.callback.ForestPredictions;
-import org.apache.mahout.df.callback.MeanTreeCollector;
-import org.apache.mahout.df.callback.MultiCallback;
import org.apache.mahout.df.data.Data;
import org.apache.mahout.df.data.DataLoader;
import org.apache.mahout.df.data.Dataset;
@@ -59,12 +56,6 @@ public class BreimanExample extends Conf
/** sum test error */
private double sumTestErr;
- /** sum mean tree error */
- private double sumTreeErr;
-
- /** sum test error with m=1 */
- private double sumOneErr;
-
/** mean time to build a forest with m=log2(M)+1 */
private long sumTimeM;
@@ -93,65 +84,38 @@ public class BreimanExample extends Conf
*/
private void runIteration(Random rng, Data data, int m, int nbtrees) {
- int nblabels = data.getDataset().nblabels();
-
log.info("Splitting the data");
Data train = data.clone();
Data test = train.rsplit(rng, (int) (data.size() * 0.1));
- int[] labels = data.extractLabels();
- int[] testLabels = test.extractLabels();
-
DefaultTreeBuilder treeBuilder = new DefaultTreeBuilder();
SequentialBuilder forestBuilder = new SequentialBuilder(rng, treeBuilder, train);
// grow a forest with m = log2(M)+1
- ForestPredictions errorM = new ForestPredictions(data.size(), nblabels); // oob error when using m =
- // log2(M)+1
treeBuilder.setM(m);
long time = System.currentTimeMillis();
log.info("Growing a forest with m={}", m);
- DecisionForest forestM = forestBuilder.build(nbtrees, errorM);
+ DecisionForest forestM = forestBuilder.build(nbtrees);
sumTimeM += System.currentTimeMillis() - time;
numNodesM += forestM.nbNodes();
- double oobM = ErrorEstimate.errorRate(labels, errorM.computePredictions(rng)); // oob error estimate
- // when m = log2(M)+1
-
// grow a forest with m=1
- ForestPredictions errorOne = new ForestPredictions(data.size(), nblabels); // oob error when using m = 1
treeBuilder.setM(1);
time = System.currentTimeMillis();
log.info("Growing a forest with m=1");
- DecisionForest forestOne = forestBuilder.build(nbtrees, errorOne);
+ DecisionForest forestOne = forestBuilder.build(nbtrees);
sumTimeOne += System.currentTimeMillis() - time;
numNodesOne += forestOne.nbNodes();
- double oobOne = ErrorEstimate.errorRate(labels, errorOne.computePredictions(rng)); // oob error
- // estimate when m
- // = 1
-
// compute the test set error (Selection Error), and mean tree error (One Tree Error),
- // using the lowest oob error forest
- ForestPredictions testError = new ForestPredictions(test.size(), nblabels); // test set error
- MeanTreeCollector treeError = new MeanTreeCollector(test, nbtrees); // mean tree error
-
- // compute the test set error using m=1 (Single Input Error)
- errorOne = new ForestPredictions(test.size(), nblabels);
-
- if (oobM < oobOne) {
- forestM.classify(test, new MultiCallback(testError, treeError));
- forestOne.classify(test, errorOne);
- } else {
- forestOne.classify(test, new MultiCallback(testError, treeError, errorOne));
- }
+ int[] testLabels = test.extractLabels();
+ int[] predictions = new int[test.size()];
+ forestM.classify(test, predictions);
- sumTestErr += ErrorEstimate.errorRate(testLabels, testError.computePredictions(rng));
- sumOneErr += ErrorEstimate.errorRate(testLabels, errorOne.computePredictions(rng));
- sumTreeErr += treeError.meanTreeError();
+ sumTestErr += ErrorEstimate.errorRate(testLabels, predictions);
}
public static void main(String[] args) throws Exception {
@@ -231,8 +195,6 @@ public class BreimanExample extends Conf
log.info("********************************************");
log.info("Selection error : {}", sumTestErr / nbIterations);
- log.info("Single Input error : {}", sumOneErr / nbIterations);
- log.info("One Tree error : {}", sumTreeErr / nbIterations);
log.info("Mean Random Input Time : {}", DFUtils.elapsedTime(sumTimeM / nbIterations));
log.info("Mean Single Input Time : {}", DFUtils.elapsedTime(sumTimeOne / nbIterations));
log.info("Mean Random Input Num Nodes : {}", numNodesM / nbIterations);
Modified:
mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapreduce/BuildForest.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapreduce/BuildForest.java?rev=1181625&r1=1181624&r2=1181625&view=diff
==============================================================================
--- mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapreduce/BuildForest.java (original)
+++ mahout/trunk/examples/src/main/java/org/apache/mahout/df/mapreduce/BuildForest.java Tue Oct 11 04:52:36 2011
@@ -18,7 +18,6 @@
package org.apache.mahout.df.mapreduce;
import java.io.IOException;
-import java.util.Random;
import org.apache.commons.cli2.CommandLine;
import org.apache.commons.cli2.Group;
@@ -35,12 +34,9 @@ import org.apache.hadoop.fs.Path;
import org.apache.hadoop.util.Tool;
import org.apache.hadoop.util.ToolRunner;
import org.apache.mahout.common.CommandLineUtil;
-import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.df.DFUtils;
import org.apache.mahout.df.DecisionForest;
-import org.apache.mahout.df.ErrorEstimate;
import org.apache.mahout.df.builder.DefaultTreeBuilder;
-import org.apache.mahout.df.callback.ForestPredictions;
import org.apache.mahout.df.data.Data;
import org.apache.mahout.df.data.DataLoader;
import org.apache.mahout.df.data.Dataset;
@@ -69,8 +65,6 @@ public class BuildForest extends Configu
private Long seed; // Random seed
private boolean isPartial; // use partial data implementation
-
- private boolean isOob; // estimate oob error;
@Override
public int run(String[] args) throws IOException, ClassNotFoundException, InterruptedException {
@@ -79,9 +73,6 @@ public class BuildForest extends Configu
ArgumentBuilder abuilder = new ArgumentBuilder();
GroupBuilder gbuilder = new GroupBuilder();
- Option oobOpt = obuilder.withShortName("oob").withRequired(false).withDescription(
- "Optional, estimate the out-of-bag error").create();
-
Option dataOpt = obuilder.withLongName("data").withShortName("d").withRequired(true).withArgument(
abuilder.withName("path").withMinimum(1).withMaximum(1).create()).withDescription("Data path").create();
@@ -111,7 +102,7 @@ public class BuildForest extends Configu
Option helpOpt = obuilder.withLongName("help").withDescription("Print out help").withShortName("h")
.create();
- Group group = gbuilder.withName("Options").withOption(oobOpt).withOption(dataOpt).withOption(datasetOpt)
+ Group group = gbuilder.withName("Options").withOption(dataOpt).withOption(datasetOpt)
.withOption(selectionOpt).withOption(seedOpt).withOption(partialOpt).withOption(nbtreesOpt)
.withOption(outputOpt).withOption(helpOpt).create();
@@ -126,7 +117,6 @@ public class BuildForest extends Configu
}
isPartial = cmdLine.hasOption(partialOpt);
- isOob = cmdLine.hasOption(oobOpt);
String dataName = cmdLine.getValue(dataOpt).toString();
String datasetName = cmdLine.getValue(datasetOpt).toString();
String outputName = cmdLine.getValue(outputOpt).toString();
@@ -144,8 +134,7 @@ public class BuildForest extends Configu
log.debug("seed : {}", seed);
log.debug("nbtrees : {}", nbTrees);
log.debug("isPartial : {}", isPartial);
- log.debug("isOob : {}", isOob);
-
+
dataPath = new Path(dataName);
datasetPath = new Path(datasetName);
outputPath = new Path(outputName);
@@ -172,11 +161,6 @@ public class BuildForest extends Configu
DefaultTreeBuilder treeBuilder = new DefaultTreeBuilder();
treeBuilder.setM(m);
- Dataset dataset = Dataset.load(getConf(), datasetPath);
-
- ForestPredictions callback = isOob ? new ForestPredictions(dataset.nbInstances(), dataset.nblabels())
- : null;
-
Builder forestBuilder;
if (isPartial) {
@@ -192,7 +176,7 @@ public class BuildForest extends Configu
log.info("Building the forest...");
long time = System.currentTimeMillis();
- DecisionForest forest = forestBuilder.build(nbTrees, callback);
+ DecisionForest forest = forestBuilder.build(nbTrees);
time = System.currentTimeMillis() - time;
log.info("Build Time: {}", DFUtils.elapsedTime(time));
@@ -200,26 +184,10 @@ public class BuildForest extends Configu
log.info("Forest mean num Nodes: {}", forest.meanNbNodes());
log.info("Forest mean max Depth: {}", forest.meanMaxDepth());
- if (isOob) {
- Random rng;
- if (seed != null) {
- rng = RandomUtils.getRandom(seed);
- } else {
- rng = RandomUtils.getRandom();
- }
-
- FileSystem fs = dataPath.getFileSystem(getConf());
- int[] labels = Data.extractLabels(dataset, fs, dataPath);
-
- log.info("oob error estimate : "
- + ErrorEstimate.errorRate(labels, callback.computePredictions(rng)));
- }
-
// store the decision forest in the output path
Path forestPath = new Path(outputPath, "forest.seq");
log.info("Storing the forest in: " + forestPath);
DFUtils.storeWritable(getConf(), forestPath, forest);
-
}
protected static Data loadData(Configuration conf, Path dataPath, Dataset dataset) throws IOException {