Author: tdunning
Date: Tue Sep 28 06:20:16 2010
New Revision: 1002033
URL: http://svn.apache.org/viewvc?rev=1002033&view=rev
Log:
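Got rid of final declarations (and the upper-case constant names that came with them) to avoid style complaints and keep the code from SHOUTING. Also added collection of overall word counts, which are printed in descending order at the end of main.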
Modified:
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainNewsGroups.java
Modified:
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainNewsGroups.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainNewsGroups.java?rev=1002033&r1=1002032&r2=1002033&view=diff
==============================================================================
---
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainNewsGroups.java
(original)
+++
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainNewsGroups.java
Tue Sep 28 06:20:16 2010
@@ -18,9 +18,11 @@
package org.apache.mahout.classifier.sgd;
import com.google.common.collect.ConcurrentHashMultiset;
+import com.google.common.collect.HashMultiset;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.collect.Multiset;
+import com.google.common.collect.Ordering;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.standard.StandardAnalyzer;
@@ -116,16 +118,16 @@ public final class TrainNewsGroups {
new SimpleDateFormat("dd-MMM-yyyy HH:mm:ss")
};
- private static final Analyzer ANALYZER = new
StandardAnalyzer(Version.LUCENE_30);
- private static final FeatureVectorEncoder ENCODER = new
StaticWordValueEncoder("body");
- private static final FeatureVectorEncoder BIAS = new
ConstantValueEncoder("Intercept");
-
- private TrainNewsGroups() {
- }
+ private static Analyzer analyzer = new StandardAnalyzer(Version.LUCENE_30);
+ private static FeatureVectorEncoder encoder = new
StaticWordValueEncoder("body");
+ private static FeatureVectorEncoder bias = new
ConstantValueEncoder("Intercept");
+ private static Multiset<String> overallCounts;
public static void main(String[] args) throws IOException {
File base = new File(args[0]);
+ overallCounts = HashMultiset.create();
+
int leakType = 0;
if (args.length > 1) {
leakType = Integer.parseInt(args[1]);
@@ -133,7 +135,7 @@ public final class TrainNewsGroups {
Dictionary newsGroups = new Dictionary();
- ENCODER.setProbes(2);
+ encoder.setProbes(2);
AdaptiveLogisticRegression learningAlgorithm = new
AdaptiveLogisticRegression(20, FEATURES, new L1());
learningAlgorithm.setInterval(800);
learningAlgorithm.setAveragingWindow(500);
@@ -215,6 +217,18 @@ public final class TrainNewsGroups {
learningAlgorithm.close();
dissect(leakType, newsGroups, learningAlgorithm, files);
System.out.println("exiting main");
+
+ List<Integer> counts = Lists.newArrayList();
+ System.out.printf("Word counts\n");
+ for (String count : overallCounts.elementSet()) {
+ counts.add(overallCounts.count(count));
+ }
+ Collections.sort(counts, Ordering.natural().reverse());
+ k = 0;
+ for (Integer count : counts) {
+ System.out.printf("%d\t%d\n", k, count);
+ k++;
+ }
}
private static void dissect(int leakType,
@@ -227,8 +241,8 @@ public final class TrainNewsGroups {
Map<String, Set<Integer>> traceDictionary = Maps.newTreeMap();
ModelDissector md = new ModelDissector();
- ENCODER.setTraceDictionary(traceDictionary);
- BIAS.setTraceDictionary(traceDictionary);
+ encoder.setTraceDictionary(traceDictionary);
+ bias.setTraceDictionary(traceDictionary);
for (File file : permute(files, rand).subList(0, 500)) {
String ng = file.getParentFile().getName();
@@ -254,7 +268,7 @@ public final class TrainNewsGroups {
try {
String line = reader.readLine();
Reader dateString = new StringReader(DATE_FORMATS[leakType %
3].format(new Date(date)));
- countWords(ANALYZER, words, dateString);
+ countWords(analyzer, words, dateString);
while (line != null && line.length() > 0) {
boolean countHeader = (
line.startsWith("From:") || line.startsWith("Subject:") ||
@@ -262,22 +276,22 @@ public final class TrainNewsGroups {
do {
Reader in = new StringReader(line);
if (countHeader) {
- countWords(ANALYZER, words, in);
+ countWords(analyzer, words, in);
}
line = reader.readLine();
} while (line.startsWith(" "));
}
if (leakType < 3) {
- countWords(ANALYZER, words, reader);
+ countWords(analyzer, words, reader);
}
} finally {
reader.close();
}
Vector v = new RandomAccessSparseVector(FEATURES);
- BIAS.addToVector("", 1, v);
+ bias.addToVector("", 1, v);
for (String word : words.elementSet()) {
- ENCODER.addToVector(word, Math.log(1 + words.count(word)), v);
+ encoder.addToVector(word, Math.log(1 + words.count(word)), v);
}
return v;
@@ -290,6 +304,7 @@ public final class TrainNewsGroups {
String s = ts.getAttribute(TermAttribute.class).term();
words.add(s);
}
+ overallCounts.addAll(words);
}
private static List<File> permute(Iterable<File> files, Random rand) {