Author: adeneche
Date: Fri Dec 23 07:29:36 2011
New Revision: 1222594

URL: http://svn.apache.org/viewvc?rev=1222594&view=rev
Log:
MAHOUT-840 DecisionTreeBuilder is now the default tree builder

Modified:
    mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/builder/InfiniteRecursionTest.java
    mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/df/mapreduce/BuildForest.java

Modified: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/builder/InfiniteRecursionTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/builder/InfiniteRecursionTest.java?rev=1222594&r1=1222593&r2=1222594&view=diff
==============================================================================
--- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/builder/InfiniteRecursionTest.java (original)
+++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/df/builder/InfiniteRecursionTest.java Fri Dec 23 07:29:36 2011
@@ -37,19 +37,24 @@ public final class InfiniteRecursionTest
   };
 
   /**
-   * make sure DefaultTreeBuilder.build() does not throw a 
StackOverflowException
+   * make sure DecisionTreeBuilder.build() does not throw a 
StackOverflowException
    */
   @Test
   public void testBuild() throws Exception {
     Random rng = RandomUtils.getRandom();
 
-    TreeBuilder builder = new DefaultTreeBuilder();
-
     String[] source = Utils.double2String(dData);
     String descriptor = "N N N N N N N N L";
+
     Dataset dataset = DataLoader.generateDataset(descriptor, false, source);
     Data data = DataLoader.loadData(dataset, source);
+    TreeBuilder builder = new DecisionTreeBuilder();
+    builder.build(rng, data);
 
+    // regression
+    dataset = DataLoader.generateDataset(descriptor, true, source);
+    data = DataLoader.loadData(dataset, source);
+    builder = new DecisionTreeBuilder();
     builder.build(rng, data);
   }
 }

Modified: mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/df/mapreduce/BuildForest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/df/mapreduce/BuildForest.java?rev=1222594&r1=1222593&r2=1222594&view=diff
==============================================================================
--- mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/df/mapreduce/BuildForest.java (original)
+++ mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/df/mapreduce/BuildForest.java Fri Dec 23 07:29:36 2011
@@ -36,8 +36,7 @@ import org.apache.hadoop.util.ToolRunner
 import org.apache.mahout.common.CommandLineUtil;
 import org.apache.mahout.classifier.df.DFUtils;
 import org.apache.mahout.classifier.df.DecisionForest;
-import org.apache.mahout.classifier.df.builder.DefaultTreeBuilder;
-import org.apache.mahout.classifier.df.builder.TreeBuilder;
+import org.apache.mahout.classifier.df.builder.DecisionTreeBuilder;
 import org.apache.mahout.classifier.df.data.Data;
 import org.apache.mahout.classifier.df.data.DataLoader;
 import org.apache.mahout.classifier.df.data.Dataset;
@@ -59,15 +58,20 @@ public class BuildForest extends Configu
   private Path datasetPath;
 
   private Path outputPath;
-  private int m; // Number of variables to select at each tree-node
+
+  private Integer m; // Number of variables to select at each tree-node
+
+  private boolean complemented; // tree is complemented
   
+  private Integer minSplitNum; // minimum number for split
+
+  private Double minVarianceProportion; // minimum proportion of the total 
variance for split
+
   private int nbTrees; // Number of trees to grow
   
   private Long seed; // Random seed
   
   private boolean isPartial; // use partial data implementation
-  
-  private String builderName; // Tree builder class name
 
   @Override
   public int run(String[] args) throws IOException, ClassNotFoundException, 
InterruptedException,
@@ -77,42 +81,57 @@ public class BuildForest extends Configu
     ArgumentBuilder abuilder = new ArgumentBuilder();
     GroupBuilder gbuilder = new GroupBuilder();
     
-    Option dataOpt = 
obuilder.withLongName("data").withShortName("d").withRequired(true).withArgument(
-      
abuilder.withName("path").withMinimum(1).withMaximum(1).create()).withDescription("Data
 path").create();
-    
-    Option datasetOpt = 
obuilder.withLongName("dataset").withShortName("ds").withRequired(true).withArgument(
-      
abuilder.withName("dataset").withMinimum(1).withMaximum(1).create()).withDescription("Dataset
 path")
-        .create();
-    
-    Option selectionOpt = 
obuilder.withLongName("selection").withShortName("sl").withRequired(true)
-        
.withArgument(abuilder.withName("m").withMinimum(1).withMaximum(1).create()).withDescription(
-          "Number of variables to select randomly at each tree-node").create();
-    
-    Option seedOpt = 
obuilder.withLongName("seed").withShortName("sd").withRequired(false).withArgument(
-      
abuilder.withName("seed").withMinimum(1).withMaximum(1).create()).withDescription(
-      "Optional, seed value used to initialise the Random number 
generator").create();
+    Option dataOpt = 
obuilder.withLongName("data").withShortName("d").withRequired(true)
+        
.withArgument(abuilder.withName("path").withMinimum(1).withMaximum(1).create())
+        .withDescription("Data path").create();
+    
+    Option datasetOpt = 
obuilder.withLongName("dataset").withShortName("ds").withRequired(true)
+        
.withArgument(abuilder.withName("dataset").withMinimum(1).withMaximum(1).create())
+        .withDescription("Dataset path").create();
+    
+    Option selectionOpt = 
obuilder.withLongName("selection").withShortName("sl").withRequired(false)
+        
.withArgument(abuilder.withName("m").withMinimum(1).withMaximum(1).create())
+        .withDescription("Optional, Number of variables to select randomly at 
each tree-node.\n" +
+        "For classification problem, the default is square root of the number 
of explanatory variables.\n" +
+        "For regression problem, the default is 1/3 of the number of 
explanatory variables.").create();
+
+    Option noCompleteOpt = 
obuilder.withLongName("no-complete").withShortName("nc").withRequired(false)
+        .withDescription("Optional, The tree is not complemented").create();
+
+    Option minSplitOpt = 
obuilder.withLongName("minsplit").withShortName("ms").withRequired(false)
+        
.withArgument(abuilder.withName("minsplit").withMinimum(1).withMaximum(1).create())
+        .withDescription("Optional, The tree-node is not divided, if the 
branching data size is " +
+        "smaller than this value.\nThe default is 2.").create();
+
+    Option minPropOpt = 
obuilder.withLongName("minprop").withShortName("mp").withRequired(false)
+        
.withArgument(abuilder.withName("minprop").withMinimum(1).withMaximum(1).create())
+        .withDescription("Optional, The tree-node is not divided, if the 
proportion of the " +
+        "variance of branching data is smaller than this value.\n" +
+        "In the case of a regression problem, this value is used. " +
+        "The default is 1/1000(0.001).").create();
+
+    Option seedOpt = 
obuilder.withLongName("seed").withShortName("sd").withRequired(false)
+        
.withArgument(abuilder.withName("seed").withMinimum(1).withMaximum(1).create())
+        .withDescription("Optional, seed value used to initialise the Random 
number generator").create();
     
     Option partialOpt = 
obuilder.withLongName("partial").withShortName("p").withRequired(false)
         .withDescription("Optional, use the Partial Data 
implementation").create();
     
-    Option nbtreesOpt = 
obuilder.withLongName("nbtrees").withShortName("t").withRequired(true).withArgument(
-      
abuilder.withName("nbtrees").withMinimum(1).withMaximum(1).create()).withDescription(
-      "Number of trees to grow").create();
-    
-    Option outputOpt = 
obuilder.withLongName("output").withShortName("o").withRequired(true).withArgument(
-        abuilder.withName("path").withMinimum(1).withMaximum(1).create()).
-        withDescription("Output path, will contain the Decision 
Forest").create();
-
-    Option builderOpt = 
obuilder.withLongName("builder").withShortName("b").withRequired(false)
-      
.withArgument(abuilder.withName("builder").withMinimum(1).withMaximum(1).create()).
-      withDescription("Tree builder class name").create();
+    Option nbtreesOpt = 
obuilder.withLongName("nbtrees").withShortName("t").withRequired(true)
+        
.withArgument(abuilder.withName("nbtrees").withMinimum(1).withMaximum(1).create())
+        .withDescription("Number of trees to grow").create();
+    
+    Option outputOpt = 
obuilder.withLongName("output").withShortName("o").withRequired(true)
+        
.withArgument(abuilder.withName("path").withMinimum(1).withMaximum(1).create())
+        .withDescription("Output path, will contain the Decision 
Forest").create();
 
-    Option helpOpt = obuilder.withLongName("help").withDescription("Print out 
help").withShortName("h")
-        .create();
+    Option helpOpt = obuilder.withLongName("help").withShortName("h")
+        .withDescription("Print out help").create();
     
     Group group = 
gbuilder.withName("Options").withOption(dataOpt).withOption(datasetOpt)
-        
.withOption(selectionOpt).withOption(seedOpt).withOption(partialOpt).withOption(nbtreesOpt)
-        
.withOption(outputOpt).withOption(builderOpt).withOption(helpOpt).create();
+        
.withOption(selectionOpt).withOption(noCompleteOpt).withOption(minSplitOpt)
+        
.withOption(minPropOpt).withOption(seedOpt).withOption(partialOpt).withOption(nbtreesOpt)
+        .withOption(outputOpt).withOption(helpOpt).create();
     
     try {
       Parser parser = new Parser();
@@ -128,28 +147,39 @@ public class BuildForest extends Configu
       String dataName = cmdLine.getValue(dataOpt).toString();
       String datasetName = cmdLine.getValue(datasetOpt).toString();
       String outputName = cmdLine.getValue(outputOpt).toString();
-      m = Integer.parseInt(cmdLine.getValue(selectionOpt).toString());
       nbTrees = Integer.parseInt(cmdLine.getValue(nbtreesOpt).toString());
       
+      if (cmdLine.hasOption(selectionOpt)) {
+        m = Integer.parseInt(cmdLine.getValue(selectionOpt).toString());
+      }
+      if (cmdLine.hasOption(noCompleteOpt)) {
+        complemented = false;
+      } else {
+        complemented = true;
+      }
+      if (cmdLine.hasOption(minSplitOpt)) {
+        minSplitNum = 
Integer.parseInt(cmdLine.getValue(minSplitOpt).toString());
+      }
+      if (cmdLine.hasOption(minPropOpt)) {
+        minVarianceProportion = 
Double.parseDouble(cmdLine.getValue(minPropOpt).toString());
+      }
       if (cmdLine.hasOption(seedOpt)) {
         seed = Long.valueOf(cmdLine.getValue(seedOpt).toString());
       }
 
-      if (cmdLine.hasOption(builderOpt)) {
-        builderName = cmdLine.getValue(builderOpt).toString();
-      }
-
       if (log.isDebugEnabled()) {
         log.debug("data : {}", dataName);
         log.debug("dataset : {}", datasetName);
         log.debug("output : {}", outputName);
         log.debug("m : {}", m);
+        log.debug("complemented : {}", complemented);
+        log.debug("minSplitNum : {}", minSplitNum);
+        log.debug("minVarianceProportion : {}", minVarianceProportion);
         log.debug("seed : {}", seed);
         log.debug("nbtrees : {}", nbTrees);
         log.debug("isPartial : {}", isPartial);
-        log.debug("builder : {}", builderName);
       }
-     
+
       dataPath = new Path(dataName);
       datasetPath = new Path(datasetName);
       outputPath = new Path(outputName);
@@ -174,13 +204,16 @@ public class BuildForest extends Configu
       return;
     }
 
-    TreeBuilder treeBuilder;
-    if (builderName == null) {
-      treeBuilder = new DefaultTreeBuilder();
-      ((DefaultTreeBuilder) treeBuilder).setM(m);
-    } else {
-      Class<?> clazz = Class.forName(builderName);
-      treeBuilder = (TreeBuilder) clazz.newInstance();
+    DecisionTreeBuilder treeBuilder = new DecisionTreeBuilder();
+    if (m != null) {
+      treeBuilder.setM(m);
+    }
+    treeBuilder.setComplemented(complemented);
+    if (minSplitNum != null) {
+      treeBuilder.setMinSplitNum(minSplitNum);
+    }
+    if (minVarianceProportion != null) {
+      treeBuilder.setMinVarianceProportion(minVarianceProportion);
     }
     
     Builder forestBuilder;


Reply via email to