Author: tdunning
Date: Fri Sep 24 00:26:47 2010
New Revision: 1000671
URL: http://svn.apache.org/viewvc?rev=1000671&view=rev
Log:
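Touched up parameters for the TrainNewsGroups example program: raised the
AdaptiveLogisticRegression interval from 200 to 800, shuffled the file list
once up front and trained on the first 10000 files (instead of a permuted
sample of 5000), and removed the unused Splitter constant and leftover
progress/debug output.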
Modified:
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainNewsGroups.java
Modified:
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainNewsGroups.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainNewsGroups.java?rev=1000671&r1=1000670&r2=1000671&view=diff
==============================================================================
--- mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainNewsGroups.java (original)
+++ mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainNewsGroups.java Fri Sep 24 00:26:47 2010
@@ -17,7 +17,6 @@
package org.apache.mahout.classifier.sgd;
-import com.google.common.base.Splitter;
import com.google.common.collect.ConcurrentHashMultiset;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
@@ -46,6 +45,7 @@ import java.io.Reader;
import java.io.StringReader;
import java.text.SimpleDateFormat;
import java.util.Arrays;
+import java.util.Collections;
import java.util.Date;
import java.util.List;
import java.util.Map;
@@ -106,8 +106,6 @@ public class TrainNewsGroups {
private static final Random rand = new Random();
- private static final Splitter ON_COLON = Splitter.on(":");
-
private static final String[] leakLabels = {"none", "month-year",
"day-month-year"};
private static final SimpleDateFormat[] df = new SimpleDateFormat[]{
new SimpleDateFormat(""),
@@ -131,7 +129,7 @@ public class TrainNewsGroups {
encoder.setProbes(2);
AdaptiveLogisticRegression learningAlgorithm = new
AdaptiveLogisticRegression(20, FEATURES, new L1());
- learningAlgorithm.setInterval(200);
+ learningAlgorithm.setInterval(800);
learningAlgorithm.setAveragingWindow(500);
List<File> files = Lists.newArrayList();
@@ -139,6 +137,7 @@ public class TrainNewsGroups {
newsGroups.intern(newsgroup.getName());
files.addAll(Arrays.asList(newsgroup.listFiles()));
}
+ Collections.shuffle(files);
System.out.printf("%d training files\n", files.size());
double averageLL = 0;
@@ -147,7 +146,7 @@ public class TrainNewsGroups {
int k = 0;
double step = 0;
int[] bumps = new int[]{1, 2, 5};
- for (File file : permute(files, rand).subList(0, 5000)) {
+ for (File file : files.subList(0, 10000)) {
String ng = file.getParentFile().getName();
int actual = newsGroups.intern(ng);
@@ -203,27 +202,25 @@ public class TrainNewsGroups {
if (k % (bump * scale) == 0) {
step += 0.25;
System.out.printf("%.2f\t%.2f\t%.2f\t%.2f\t%.8f\t%.8f\t", maxBeta,
nonZeros, positive, norm, lambda, mu);
- System.out.printf("%d\t%.3f\t%.2f\t%s\t%s\n",
- k, averageLL, averageCorrect * 100, ng, leakLabels[leakType % 3]);
+ System.out.printf("%d\t%.3f\t%.2f\t%s\n",
+ k, averageLL, averageCorrect * 100, leakLabels[leakType % 3]);
}
}
learningAlgorithm.close();
-
dissect(leakType, newsGroups, learningAlgorithm, files);
System.out.println("exiting main");
}
private static void dissect(int leakType, Dictionary newsGroups,
AdaptiveLogisticRegression learningAlgorithm, List<File> files) throws
IOException {
+ CrossFoldLearner model =
learningAlgorithm.getBest().getPayload().getLearner();
+ model.close();
+
Map<String, Set<Integer>> traceDictionary = Maps.newTreeMap();
- System.out.printf("starting dissection\n");
- ModelDissector md = new
ModelDissector(learningAlgorithm.getBest().getPayload().getLearner().numCategories());
+ ModelDissector md = new ModelDissector(model.numCategories());
encoder.setTraceDictionary(traceDictionary);
bias.setTraceDictionary(traceDictionary);
- int k = 0;
- CrossFoldLearner model =
learningAlgorithm.getBest().getPayload().getLearner();
- model.close();
-
+
for (File file : permute(files, rand).subList(0, 500)) {
String ng = file.getParentFile().getName();
int actual = newsGroups.intern(ng);
@@ -231,10 +228,6 @@ public class TrainNewsGroups {
traceDictionary.clear();
Vector v = encodeFeatureVector(file, actual, leakType);
md.update(v, traceDictionary, model);
- if (k % 100 == 0) {
- System.out.printf("%d\t%d\n", k, traceDictionary.size());
- }
- k++;
}
List<String> ngNames = Lists.newArrayList(newsGroups.values());