Author: adeneche
Date: Fri Dec 23 07:29:36 2011
New Revision: 1222594
URL: http://svn.apache.org/viewvc?rev=1222594&view=rev
Log:
MAHOUT-840 DecisionTreeBuilder is now the default tree builder
Modified:
mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/builder/InfiniteRecursionTest.java
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/df/mapreduce/BuildForest.java
Modified:
mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/builder/InfiniteRecursionTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/builder/InfiniteRecursionTest.java?rev=1222594&r1=1222593&r2=1222594&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/builder/InfiniteRecursionTest.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/builder/InfiniteRecursionTest.java Fri Dec 23 07:29:36 2011
@@ -37,19 +37,24 @@ public final class InfiniteRecursionTest
};
/**
- * make sure DefaultTreeBuilder.build() does not throw a StackOverflowException
+ * make sure DecisionTreeBuilder.build() does not throw a StackOverflowException
*/
@Test
public void testBuild() throws Exception {
Random rng = RandomUtils.getRandom();
- TreeBuilder builder = new DefaultTreeBuilder();
-
String[] source = Utils.double2String(dData);
String descriptor = "N N N N N N N N L";
+
Dataset dataset = DataLoader.generateDataset(descriptor, false, source);
Data data = DataLoader.loadData(dataset, source);
+ TreeBuilder builder = new DecisionTreeBuilder();
+ builder.build(rng, data);
+ // regression
+ dataset = DataLoader.generateDataset(descriptor, true, source);
+ data = DataLoader.loadData(dataset, source);
+ builder = new DecisionTreeBuilder();
builder.build(rng, data);
}
}
Modified:
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/df/mapreduce/BuildForest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/df/mapreduce/BuildForest.java?rev=1222594&r1=1222593&r2=1222594&view=diff
==============================================================================
--- mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/df/mapreduce/BuildForest.java (original)
+++ mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/df/mapreduce/BuildForest.java Fri Dec 23 07:29:36 2011
@@ -36,8 +36,7 @@ import org.apache.hadoop.util.ToolRunner
import org.apache.mahout.common.CommandLineUtil;
import org.apache.mahout.classifier.df.DFUtils;
import org.apache.mahout.classifier.df.DecisionForest;
-import org.apache.mahout.classifier.df.builder.DefaultTreeBuilder;
-import org.apache.mahout.classifier.df.builder.TreeBuilder;
+import org.apache.mahout.classifier.df.builder.DecisionTreeBuilder;
import org.apache.mahout.classifier.df.data.Data;
import org.apache.mahout.classifier.df.data.DataLoader;
import org.apache.mahout.classifier.df.data.Dataset;
@@ -59,15 +58,20 @@ public class BuildForest extends Configu
private Path datasetPath;
private Path outputPath;
- private int m; // Number of variables to select at each tree-node
+
+ private Integer m; // Number of variables to select at each tree-node
+
+ private boolean complemented; // tree is complemented
+ private Integer minSplitNum; // minimum number for split
+
+ private Double minVarianceProportion; // minimum proportion of the total variance for split
+
private int nbTrees; // Number of trees to grow
private Long seed; // Random seed
private boolean isPartial; // use partial data implementation
-
- private String builderName; // Tree builder class name
@Override
public int run(String[] args) throws IOException, ClassNotFoundException,
InterruptedException,
@@ -77,42 +81,57 @@ public class BuildForest extends Configu
ArgumentBuilder abuilder = new ArgumentBuilder();
GroupBuilder gbuilder = new GroupBuilder();
- Option dataOpt =
obuilder.withLongName("data").withShortName("d").withRequired(true).withArgument(
-
abuilder.withName("path").withMinimum(1).withMaximum(1).create()).withDescription("Data
path").create();
-
- Option datasetOpt =
obuilder.withLongName("dataset").withShortName("ds").withRequired(true).withArgument(
-
abuilder.withName("dataset").withMinimum(1).withMaximum(1).create()).withDescription("Dataset
path")
- .create();
-
- Option selectionOpt =
obuilder.withLongName("selection").withShortName("sl").withRequired(true)
-
.withArgument(abuilder.withName("m").withMinimum(1).withMaximum(1).create()).withDescription(
- "Number of variables to select randomly at each tree-node").create();
-
- Option seedOpt =
obuilder.withLongName("seed").withShortName("sd").withRequired(false).withArgument(
-
abuilder.withName("seed").withMinimum(1).withMaximum(1).create()).withDescription(
- "Optional, seed value used to initialise the Random number
generator").create();
+ Option dataOpt =
obuilder.withLongName("data").withShortName("d").withRequired(true)
+
.withArgument(abuilder.withName("path").withMinimum(1).withMaximum(1).create())
+ .withDescription("Data path").create();
+
+ Option datasetOpt =
obuilder.withLongName("dataset").withShortName("ds").withRequired(true)
+
.withArgument(abuilder.withName("dataset").withMinimum(1).withMaximum(1).create())
+ .withDescription("Dataset path").create();
+
+ Option selectionOpt =
obuilder.withLongName("selection").withShortName("sl").withRequired(false)
+
.withArgument(abuilder.withName("m").withMinimum(1).withMaximum(1).create())
+ .withDescription("Optional, Number of variables to select randomly at
each tree-node.\n" +
+ "For classification problem, the default is square root of the number
of explanatory variables.\n" +
+ "For regression problem, the default is 1/3 of the number of
explanatory variables.").create();
+
+ Option noCompleteOpt =
obuilder.withLongName("no-complete").withShortName("nc").withRequired(false)
+ .withDescription("Optional, The tree is not complemented").create();
+
+ Option minSplitOpt =
obuilder.withLongName("minsplit").withShortName("ms").withRequired(false)
+
.withArgument(abuilder.withName("minsplit").withMinimum(1).withMaximum(1).create())
+ .withDescription("Optional, The tree-node is not divided, if the
branching data size is " +
+ "smaller than this value.\nThe default is 2.").create();
+
+ Option minPropOpt =
obuilder.withLongName("minprop").withShortName("mp").withRequired(false)
+
.withArgument(abuilder.withName("minprop").withMinimum(1).withMaximum(1).create())
+ .withDescription("Optional, The tree-node is not divided, if the
proportion of the " +
+ "variance of branching data is smaller than this value.\n" +
+ "In the case of a regression problem, this value is used. " +
+ "The default is 1/1000(0.001).").create();
+
+ Option seedOpt =
obuilder.withLongName("seed").withShortName("sd").withRequired(false)
+
.withArgument(abuilder.withName("seed").withMinimum(1).withMaximum(1).create())
+ .withDescription("Optional, seed value used to initialise the Random
number generator").create();
Option partialOpt =
obuilder.withLongName("partial").withShortName("p").withRequired(false)
.withDescription("Optional, use the Partial Data
implementation").create();
- Option nbtreesOpt =
obuilder.withLongName("nbtrees").withShortName("t").withRequired(true).withArgument(
-
abuilder.withName("nbtrees").withMinimum(1).withMaximum(1).create()).withDescription(
- "Number of trees to grow").create();
-
- Option outputOpt =
obuilder.withLongName("output").withShortName("o").withRequired(true).withArgument(
- abuilder.withName("path").withMinimum(1).withMaximum(1).create()).
- withDescription("Output path, will contain the Decision
Forest").create();
-
- Option builderOpt =
obuilder.withLongName("builder").withShortName("b").withRequired(false)
-
.withArgument(abuilder.withName("builder").withMinimum(1).withMaximum(1).create()).
- withDescription("Tree builder class name").create();
+ Option nbtreesOpt =
obuilder.withLongName("nbtrees").withShortName("t").withRequired(true)
+
.withArgument(abuilder.withName("nbtrees").withMinimum(1).withMaximum(1).create())
+ .withDescription("Number of trees to grow").create();
+
+ Option outputOpt =
obuilder.withLongName("output").withShortName("o").withRequired(true)
+
.withArgument(abuilder.withName("path").withMinimum(1).withMaximum(1).create())
+ .withDescription("Output path, will contain the Decision
Forest").create();
- Option helpOpt = obuilder.withLongName("help").withDescription("Print out
help").withShortName("h")
- .create();
+ Option helpOpt = obuilder.withLongName("help").withShortName("h")
+ .withDescription("Print out help").create();
Group group =
gbuilder.withName("Options").withOption(dataOpt).withOption(datasetOpt)
-
.withOption(selectionOpt).withOption(seedOpt).withOption(partialOpt).withOption(nbtreesOpt)
-
.withOption(outputOpt).withOption(builderOpt).withOption(helpOpt).create();
+
.withOption(selectionOpt).withOption(noCompleteOpt).withOption(minSplitOpt)
+
.withOption(minPropOpt).withOption(seedOpt).withOption(partialOpt).withOption(nbtreesOpt)
+ .withOption(outputOpt).withOption(helpOpt).create();
try {
Parser parser = new Parser();
@@ -128,28 +147,39 @@ public class BuildForest extends Configu
String dataName = cmdLine.getValue(dataOpt).toString();
String datasetName = cmdLine.getValue(datasetOpt).toString();
String outputName = cmdLine.getValue(outputOpt).toString();
- m = Integer.parseInt(cmdLine.getValue(selectionOpt).toString());
nbTrees = Integer.parseInt(cmdLine.getValue(nbtreesOpt).toString());
+ if (cmdLine.hasOption(selectionOpt)) {
+ m = Integer.parseInt(cmdLine.getValue(selectionOpt).toString());
+ }
+ if (cmdLine.hasOption(noCompleteOpt)) {
+ complemented = false;
+ } else {
+ complemented = true;
+ }
+ if (cmdLine.hasOption(minSplitOpt)) {
+ minSplitNum =
Integer.parseInt(cmdLine.getValue(minSplitOpt).toString());
+ }
+ if (cmdLine.hasOption(minPropOpt)) {
+ minVarianceProportion =
Double.parseDouble(cmdLine.getValue(minPropOpt).toString());
+ }
if (cmdLine.hasOption(seedOpt)) {
seed = Long.valueOf(cmdLine.getValue(seedOpt).toString());
}
- if (cmdLine.hasOption(builderOpt)) {
- builderName = cmdLine.getValue(builderOpt).toString();
- }
-
if (log.isDebugEnabled()) {
log.debug("data : {}", dataName);
log.debug("dataset : {}", datasetName);
log.debug("output : {}", outputName);
log.debug("m : {}", m);
+ log.debug("complemented : {}", complemented);
+ log.debug("minSplitNum : {}", minSplitNum);
+ log.debug("minVarianceProportion : {}", minVarianceProportion);
log.debug("seed : {}", seed);
log.debug("nbtrees : {}", nbTrees);
log.debug("isPartial : {}", isPartial);
- log.debug("builder : {}", builderName);
}
-
+
dataPath = new Path(dataName);
datasetPath = new Path(datasetName);
outputPath = new Path(outputName);
@@ -174,13 +204,16 @@ public class BuildForest extends Configu
return;
}
- TreeBuilder treeBuilder;
- if (builderName == null) {
- treeBuilder = new DefaultTreeBuilder();
- ((DefaultTreeBuilder) treeBuilder).setM(m);
- } else {
- Class<?> clazz = Class.forName(builderName);
- treeBuilder = (TreeBuilder) clazz.newInstance();
+ DecisionTreeBuilder treeBuilder = new DecisionTreeBuilder();
+ if (m != null) {
+ treeBuilder.setM(m);
+ }
+ treeBuilder.setComplemented(complemented);
+ if (minSplitNum != null) {
+ treeBuilder.setMinSplitNum(minSplitNum);
+ }
+ if (minVarianceProportion != null) {
+ treeBuilder.setMinVarianceProportion(minVarianceProportion);
}
Builder forestBuilder;