Author: tdunning
Date: Fri Sep 24 00:26:47 2010
New Revision: 1000671

URL: http://svn.apache.org/viewvc?rev=1000671&view=rev
Log:
Touched up parameters for example program

Modified:
    mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainNewsGroups.java

Modified: mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainNewsGroups.java
URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainNewsGroups.java?rev=1000671&r1=1000670&r2=1000671&view=diff
==============================================================================
--- mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainNewsGroups.java (original)
+++ mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainNewsGroups.java Fri Sep 24 00:26:47 2010
@@ -17,7 +17,6 @@
 
 package org.apache.mahout.classifier.sgd;
 
-import com.google.common.base.Splitter;
 import com.google.common.collect.ConcurrentHashMultiset;
 import com.google.common.collect.Lists;
 import com.google.common.collect.Maps;
@@ -46,6 +45,7 @@ import java.io.Reader;
 import java.io.StringReader;
 import java.text.SimpleDateFormat;
 import java.util.Arrays;
+import java.util.Collections;
 import java.util.Date;
 import java.util.List;
 import java.util.Map;
@@ -106,8 +106,6 @@ public class TrainNewsGroups {
 
   private static final Random rand = new Random();
 
-  private static final Splitter ON_COLON = Splitter.on(":");
-
   private static final String[] leakLabels = {"none", "month-year", "day-month-year"};
   private static final SimpleDateFormat[] df = new SimpleDateFormat[]{
     new SimpleDateFormat(""),
@@ -131,7 +129,7 @@ public class TrainNewsGroups {
 
     encoder.setProbes(2);
     AdaptiveLogisticRegression learningAlgorithm = new AdaptiveLogisticRegression(20, FEATURES, new L1());
-    learningAlgorithm.setInterval(200);
+    learningAlgorithm.setInterval(800);
     learningAlgorithm.setAveragingWindow(500);
 
     List<File> files = Lists.newArrayList();
@@ -139,6 +137,7 @@ public class TrainNewsGroups {
       newsGroups.intern(newsgroup.getName());
       files.addAll(Arrays.asList(newsgroup.listFiles()));
     }
+    Collections.shuffle(files);
     System.out.printf("%d training files\n", files.size());
 
     double averageLL = 0;
@@ -147,7 +146,7 @@ public class TrainNewsGroups {
     int k = 0;
     double step = 0;
     int[] bumps = new int[]{1, 2, 5};
-    for (File file : permute(files, rand).subList(0, 5000)) {
+    for (File file : files.subList(0, 10000)) {
       String ng = file.getParentFile().getName();
       int actual = newsGroups.intern(ng);
 
@@ -203,27 +202,25 @@ public class TrainNewsGroups {
       if (k % (bump * scale) == 0) {
         step += 0.25;
         System.out.printf("%.2f\t%.2f\t%.2f\t%.2f\t%.8f\t%.8f\t", maxBeta, nonZeros, positive, norm, lambda, mu);
-        System.out.printf("%d\t%.3f\t%.2f\t%s\t%s\n",
-          k, averageLL, averageCorrect * 100, ng, leakLabels[leakType % 3]);
+        System.out.printf("%d\t%.3f\t%.2f\t%s\n",
+          k, averageLL, averageCorrect * 100, leakLabels[leakType % 3]);
       }
     }
     learningAlgorithm.close();
-
     dissect(leakType, newsGroups, learningAlgorithm, files);
     System.out.println("exiting main");
   }
 
   private static void dissect(int leakType, Dictionary newsGroups, AdaptiveLogisticRegression learningAlgorithm, List<File> files) throws IOException {
+    CrossFoldLearner model = learningAlgorithm.getBest().getPayload().getLearner();
+    model.close();
+
     Map<String, Set<Integer>> traceDictionary = Maps.newTreeMap();
-    System.out.printf("starting dissection\n");
-    ModelDissector md = new ModelDissector(learningAlgorithm.getBest().getPayload().getLearner().numCategories());
+    ModelDissector md = new ModelDissector(model.numCategories());
 
     encoder.setTraceDictionary(traceDictionary);
     bias.setTraceDictionary(traceDictionary);
-    int k = 0;
-    CrossFoldLearner model = learningAlgorithm.getBest().getPayload().getLearner();
-    model.close();
-    
+
     for (File file : permute(files, rand).subList(0, 500)) {
       String ng = file.getParentFile().getName();
       int actual = newsGroups.intern(ng);
@@ -231,10 +228,6 @@ public class TrainNewsGroups {
       traceDictionary.clear();
       Vector v = encodeFeatureVector(file, actual, leakType);
       md.update(v, traceDictionary, model);
-      if (k % 100 == 0) {
-        System.out.printf("%d\t%d\n", k, traceDictionary.size());
-      }
-      k++;
     }
 
     List<String> ngNames = Lists.newArrayList(newsGroups.values());


Reply via email to