Author: gsingers
Date: Sat Nov 12 08:19:18 2011
New Revision: 1201223

URL: http://svn.apache.org/viewvc?rev=1201223&view=rev
Log:
MAHOUT-851: add in SGD example for ASF email

Added:
    
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/SGDHelper.java
    
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/SGDInfo.java
    
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TestASFEmail.java
      - copied, changed from r1200329, 
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TestNewsGroups.java
    
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainASFEmail.java
Modified:
    mahout/trunk/examples/bin/build-asf-email.sh
    
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainNewsGroups.java

Modified: mahout/trunk/examples/bin/build-asf-email.sh
URL: 
http://svn.apache.org/viewvc/mahout/trunk/examples/bin/build-asf-email.sh?rev=1201223&r1=1201222&r2=1201223&view=diff
==============================================================================
--- mahout/trunk/examples/bin/build-asf-email.sh (original)
+++ mahout/trunk/examples/bin/build-asf-email.sh Sat Nov 12 08:19:18 2011
@@ -109,11 +109,16 @@ elif [ "x$alg" == "xclassification" ]; t
   echo "Please select a number to choose the corresponding algorithm to run"
   echo "1. ${algorithm[0]}"
   echo "2. ${algorithm[1]}"
-#  echo "3. ${algorithm[2]}"
+  echo "3. ${algorithm[2]}"
   read -p "Enter your choice : " choice
 
   echo "ok. You chose $choice and we'll use ${algorithm[$choice-1]}"
   classAlg=${algorithm[$choice-1]}
+
+  if [ "x$classAlg" == "xsgd"  ]; then
+    echo "How many labels/projects are there in the data set:"
+    read -p "Enter your choice : " numLabels
+  fi
   #Convert mail to be formatted as:
   # label\ttext
   # One per line
@@ -167,6 +172,7 @@ elif [ "x$alg" == "xclassification" ]; t
     TRAIN="$SPLIT/train"
     TEST="$SPLIT/test"
     TEST_OUT="$CLASS/test-results"
+    MODELS="$CLASS/models"
     LABEL="$SPLIT/labels"
     if [ "x$OVER" == "xover" ] || [ ! -e "$MAIL_OUT/chunk-0" ]; then
       echo "Converting Mail files to Sequence Files"
@@ -182,12 +188,14 @@ elif [ "x$alg" == "xclassification" ]; t
       echo "Creating training and test inputs from $SEQ2SPLABEL"
       $MAHOUT split --input $SEQ2SPLABEL --trainingOutput $TRAIN --testOutput 
$TEST --randomSelectionPct 20 --overwrite --sequenceFiles
     fi
-    MODEL="$CLASS/model"
+    MODEL="$MODELS/asf.model"
+
 
     echo "Running SGD Training"
-    #$MAHOUT trainnb -i $TRAIN -o $MODEL --extractLabels --labelIndex $LABEL 
--overwrite
+    $MAHOUT org.apache.mahout.classifier.sgd.TrainASFEmail $TRAIN $MODELS 
$numLabels 5000
     echo "Running Test"
-#$MAHOUT testnb -i $TEST -o $TEST_OUT -m $MODEL --labelIndex $LABEL --overwrite
+    # Model file written by TrainASFEmail into $MODELS (no leading '$' when assigning)
+    MODEL="$MODELS/asf.model"
+    $MAHOUT org.apache.mahout.classifier.sgd.TestASFEmail --input $TEST 
--model $MODEL
 
   fi
 fi

Added: 
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/SGDHelper.java
URL: 
http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/SGDHelper.java?rev=1201223&view=auto
==============================================================================
--- 
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/SGDHelper.java
 (added)
+++ 
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/SGDHelper.java
 Sat Nov 12 08:19:18 2011
@@ -0,0 +1,149 @@
+package org.apache.mahout.classifier.sgd;
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+import com.google.common.collect.Lists;
+import com.google.common.collect.Maps;
+import com.google.common.collect.Multiset;
+import org.apache.mahout.ep.State;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.function.DoubleFunction;
+import org.apache.mahout.math.function.Functions;
+import org.apache.mahout.vectorizer.encoders.Dictionary;
+
+import java.io.File;
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+import java.util.Set;
+
+/**
+ * Static helpers shared by the SGD example trainers ({@link TrainNewsGroups} and
+ * {@link TrainASFEmail}): periodic progress reporting while training and a dissection
+ * of the trained model that is printed to standard output.
+ */
+public class SGDHelper {
+  /** Labels printed by {@link #analyzeState}, selected with {@code leakType % 3}. */
+  private static final String[] LEAK_LABELS = {"none", "month-year", "day-month-year"};
+
+  /** Upper bound on the number of files re-encoded by {@link #dissect}. */
+  private static final int MAX_DISSECT_FILES = 500;
+
+  /**
+   * Closes the best cross-fold learner, re-encodes a random sample of at most
+   * {@code MAX_DISSECT_FILES} files with feature tracing enabled, feeds each vector to a
+   * {@link ModelDissector} and prints the result of {@code summary(100)}.
+   *
+   * @param leakType          passed through to the feature encoder
+   * @param newsGroups        dictionary used to turn a label name into its integer id
+   * @param learningAlgorithm learner whose current best model is dissected
+   * @param files             candidate files; the parent directory name of each file is its label
+   * @param overallCounts     word-count accumulator handed to the encoder
+   * @throws IOException propagated from the encoder
+   */
+  public static void dissect(int leakType,
+                             Dictionary newsGroups,
+                             AdaptiveLogisticRegression learningAlgorithm,
+                             Iterable<File> files, Multiset<String> overallCounts) throws IOException {
+    CrossFoldLearner model = learningAlgorithm.getBest().getPayload().getLearner();
+    model.close();
+
+    Map<String, Set<Integer>> traceDictionary = Maps.newTreeMap();
+    ModelDissector md = new ModelDissector();
+
+    NewsgroupHelper helper = new NewsgroupHelper();
+    helper.getEncoder().setTraceDictionary(traceDictionary);
+    helper.getBias().setTraceDictionary(traceDictionary);
+
+    // Clamp the sample size: an unconditional subList(0, 500) throws
+    // IndexOutOfBoundsException when fewer than 500 files are available.
+    List<File> shuffled = permute(files, helper.getRandom());
+    List<File> sample = shuffled.subList(0, Math.min(MAX_DISSECT_FILES, shuffled.size()));
+    for (File file : sample) {
+      String ng = file.getParentFile().getName();
+      int actual = newsGroups.intern(ng);
+
+      // the trace dictionary must only describe the vector that is about to be encoded
+      traceDictionary.clear();
+      Vector v = helper.encodeFeatureVector(file, actual, leakType, overallCounts);
+      md.update(v, traceDictionary, model);
+    }
+
+    List<String> ngNames = Lists.newArrayList(newsGroups.values());
+    List<ModelDissector.Weight> weights = md.summary(100);
+    System.out.println("============");
+    System.out.println("Model Dissection");
+    for (ModelDissector.Weight w : weights) {
+      // NOTE(review): the "+ 1" offset on getMaxImpact() is kept from the original code;
+      // confirm it matches how ModelDissector numbers categories.
+      System.out.printf("%s\t%.1f\t%s\t%.1f\t%s\t%.1f\t%s\n",
+                        w.getFeature(), w.getWeight(), ngNames.get(w.getMaxImpact() + 1),
+                        w.getCategory(1), w.getWeight(1), w.getCategory(2), w.getWeight(2));
+    }
+  }
+
+  /**
+   * Returns the given files in random order. Each incoming file is placed at a randomly
+   * chosen index of the result; the element previously at that index moves to the end.
+   *
+   * @param files files to shuffle; the iterable itself is not modified
+   * @param rand  source of randomness
+   * @return a new list containing every file exactly once
+   */
+  public static List<File> permute(Iterable<File> files, Random rand) {
+    List<File> r = Lists.newArrayList();
+    for (File file : files) {
+      int i = rand.nextInt(r.size() + 1);
+      if (i == r.size()) {
+        r.add(file);
+      } else {
+        r.add(r.get(i));
+        r.set(i, file);
+      }
+    }
+    return r;
+  }
+
+  /**
+   * Records the accuracy and log likelihood of the current best learner in {@code info} and,
+   * when {@code k} is a multiple of the current reporting interval, saves a model snapshot
+   * and prints one tab-separated progress line.
+   * <p/>
+   * The interval is {@code bumps[floor(step) % bumps.length] * 10^floor(step / bumps.length)};
+   * {@code info.step} grows by 0.25 each time a line is printed.
+   *
+   * @param info     progress state updated in place
+   * @param leakType only used to pick the label printed at the end of the progress line
+   * @param k        number of training examples seen so far
+   * @param best     current best state, or {@code null}, in which case zeros are reported
+   * @throws IOException propagated from writing the model snapshot
+   */
+  static void analyzeState(SGDInfo info, int leakType, int k,
+                           State<AdaptiveLogisticRegression.Wrapper, CrossFoldLearner> best)
+    throws IOException {
+    int bump = info.bumps[(int) Math.floor(info.step) % info.bumps.length];
+    int scale = (int) Math.pow(10, Math.floor(info.step / info.bumps.length));
+    double maxBeta;
+    double nonZeros;
+    double positive;
+    double norm;
+
+    double lambda = 0;
+    double mu = 0;
+
+    if (best != null) {
+      CrossFoldLearner state = best.getPayload().getLearner();
+      info.averageCorrect = state.percentCorrect();
+      info.averageLL = state.logLikelihood();
+
+      OnlineLogisticRegression model = state.getModels().get(0);
+      // finish off pending regularization
+      model.close();
+
+      Matrix beta = model.getBeta();
+      // largest absolute coefficient
+      maxBeta = beta.aggregate(Functions.MAX, Functions.ABS);
+      // number of coefficients whose magnitude exceeds 1e-6
+      nonZeros = beta.aggregate(Functions.PLUS, new DoubleFunction() {
+        @Override
+        public double apply(double v) {
+          return Math.abs(v) > 1.0e-6 ? 1 : 0;
+        }
+      });
+      // number of strictly positive coefficients
+      positive = beta.aggregate(Functions.PLUS, new DoubleFunction() {
+        @Override
+        public double apply(double v) {
+          return v > 0 ? 1 : 0;
+        }
+      });
+      // sum of absolute coefficient values
+      norm = beta.aggregate(Functions.PLUS, Functions.ABS);
+
+      lambda = best.getMappedParams()[0];
+      mu = best.getMappedParams()[1];
+    } else {
+      maxBeta = 0;
+      nonZeros = 0;
+      positive = 0;
+      norm = 0;
+    }
+    if (k % (bump * scale) == 0) {
+      if (best != null) {
+        // snapshot of the first fold's model at this checkpoint
+        ModelSerializer.writeBinary("/tmp/news-group-" + k + ".model",
+                best.getPayload().getLearner().getModels().get(0));
+      }
+
+      info.step += 0.25;
+      System.out.printf("%.2f\t%.2f\t%.2f\t%.2f\t%.8g\t%.8g\t", maxBeta, nonZeros, positive, norm, lambda, mu);
+      System.out.printf("%d\t%.3f\t%.2f\t%s\n",
+        k, info.averageLL, info.averageCorrect * 100, LEAK_LABELS[leakType % 3]);
+    }
+  }
+
+}

Added: 
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/SGDInfo.java
URL: 
http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/SGDInfo.java?rev=1201223&view=auto
==============================================================================
--- 
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/SGDInfo.java
 (added)
+++ 
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/SGDInfo.java
 Sat Nov 12 08:19:18 2011
@@ -0,0 +1,30 @@
+package org.apache.mahout.classifier.sgd;
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+/**
+ * Mutable progress state that {@code SGDHelper.analyzeState} reads and updates between
+ * successive training examples.
+ */
+class SGDInfo {
+  /** Most recent log likelihood taken from the best learner; starts at zero. */
+  double averageLL;
+  /** Most recent fraction of correct classifications from the best learner; starts at zero. */
+  double averageCorrect;
+  /** Position in the reporting schedule; advanced by 0.25 for every progress line. */
+  double step;
+  /** Base multipliers cycled through when computing the reporting interval. */
+  int[] bumps = new int[] {1, 2, 5};
+}

Copied: 
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TestASFEmail.java
 (from r1200329, 
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TestNewsGroups.java)
URL: 
http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TestASFEmail.java?p2=mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TestASFEmail.java&p1=mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TestNewsGroups.java&r1=1200329&r2=1201223&rev=1201223&view=diff
==============================================================================
--- 
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TestNewsGroups.java
 (original)
+++ 
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TestASFEmail.java
 Sat Nov 12 08:19:18 2011
@@ -18,7 +18,6 @@
 package org.apache.mahout.classifier.sgd;
 
 import com.google.common.collect.HashMultiset;
-import com.google.common.collect.Lists;
 import com.google.common.collect.Multiset;
 import org.apache.commons.cli2.CommandLine;
 import org.apache.commons.cli2.Group;
@@ -28,31 +27,37 @@ import org.apache.commons.cli2.builder.D
 import org.apache.commons.cli2.builder.GroupBuilder;
 import org.apache.commons.cli2.commandline.Parser;
 import org.apache.commons.cli2.util.HelpFormatter;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.Text;
 import org.apache.mahout.classifier.ClassifierResult;
 import org.apache.mahout.classifier.ResultAnalyzer;
+import org.apache.mahout.common.Pair;
+import org.apache.mahout.common.iterator.sequencefile.PathFilters;
+import org.apache.mahout.common.iterator.sequencefile.PathType;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterator;
 import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.VectorWritable;
 import org.apache.mahout.vectorizer.encoders.Dictionary;
 
 import java.io.File;
 import java.io.FileInputStream;
 import java.io.IOException;
 import java.io.PrintWriter;
-import java.util.Arrays;
-import java.util.List;
 
 /**
- * Run the 20 news groups test data through SGD, as trained by {@link 
org.apache.mahout.classifier.sgd.TrainNewsGroups}.
+ * Run the ASF email test data through SGD, as trained by {@link TrainASFEmail}.
  */
-public final class TestNewsGroups {
+public final class TestASFEmail {
 
   private String inputFile;
   private String modelFile;
 
-  private TestNewsGroups() {
+  private TestASFEmail() {
   }
 
   public static void main(String[] args) throws IOException {
-    TestNewsGroups runner = new TestNewsGroups();
+    TestASFEmail runner = new TestASFEmail();
     if (runner.parseArgs(args)) {
       runner.run(new PrintWriter(System.out, true));
     }
@@ -65,30 +70,35 @@ public final class TestNewsGroups {
     OnlineLogisticRegression classifier = ModelSerializer.readBinary(new 
FileInputStream(modelFile), OnlineLogisticRegression.class);
 
 
-    Dictionary newsGroups = new Dictionary();
+    Dictionary asfDictionary = new Dictionary();
     Multiset<String> overallCounts = HashMultiset.create();
-
-    List<File> files = Lists.newArrayList();
-    for (File newsgroup : base.listFiles()) {
-      if (newsgroup.isDirectory()) {
-        newsGroups.intern(newsgroup.getName());
-        files.addAll(Arrays.asList(newsgroup.listFiles()));
-      }
+    Configuration conf = new Configuration();
+    SequenceFileDirIterator<Text, VectorWritable> iter = new 
SequenceFileDirIterator<Text, VectorWritable>(new Path(base.toString()), 
PathType.LIST, PathFilters.partFilter(),
+            null, true, conf);
+
+    long numItems = 0;
+    while (iter.hasNext()) {
+      Pair<Text, VectorWritable> next = iter.next();
+      asfDictionary.intern(next.getFirst().toString());
+      numItems++;
     }
-    System.out.printf("%d test files\n", files.size());
-    ResultAnalyzer ra = new ResultAnalyzer(newsGroups.values(), "DEFAULT");
-    for (File file : files) {
-      String ng = file.getParentFile().getName();
 
-      int actual = newsGroups.intern(ng);
+    System.out.printf("%d test files\n", numItems);
+    ResultAnalyzer ra = new ResultAnalyzer(asfDictionary.values(), "DEFAULT");
+    iter = new SequenceFileDirIterator<Text, VectorWritable>(new 
Path(base.toString()), PathType.LIST, PathFilters.partFilter(),
+            null, true, conf);
+    while (iter.hasNext()){
+      Pair<Text, VectorWritable> next = iter.next();
+      String ng = next.getFirst().toString();
+
+      int actual = asfDictionary.intern(ng);
       NewsgroupHelper helper = new NewsgroupHelper();
-      Vector input = helper.encodeFeatureVector(file, actual, 0, 
overallCounts);//no leak type ensures this is a normal vector
-      Vector result = classifier.classifyFull(input);
+      Vector result = classifier.classifyFull(next.getSecond().get());
       int cat = result.maxValueIndex();
       double score = result.maxValue();
-      double ll = classifier.logLikelihood(actual, input);
-      ClassifierResult cr = new ClassifierResult(newsGroups.values().get(cat), 
score, ll);
-      ra.addInstance(newsGroups.values().get(actual), cr);
+      double ll = classifier.logLikelihood(actual, next.getSecond().get());
+      ClassifierResult cr = new 
ClassifierResult(asfDictionary.values().get(cat), score, ll);
+      ra.addInstance(asfDictionary.values().get(actual), cr);
 
     }
     output.printf("%s\n\n", ra.toString());

Added: 
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainASFEmail.java
URL: 
http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainASFEmail.java?rev=1201223&view=auto
==============================================================================
--- 
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainASFEmail.java
 (added)
+++ 
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainASFEmail.java
 Sat Nov 12 08:19:18 2011
@@ -0,0 +1,126 @@
+package org.apache.mahout.classifier.sgd;
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+
+import com.google.common.collect.HashMultiset;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Multiset;
+import com.google.common.collect.Ordering;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.Text;
+import org.apache.mahout.common.Pair;
+import org.apache.mahout.common.iterator.sequencefile.PathFilters;
+import org.apache.mahout.common.iterator.sequencefile.PathType;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterator;
+import org.apache.mahout.ep.State;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.vectorizer.encoders.Dictionary;
+
+import java.io.File;
+import java.io.IOException;
+import java.util.Collections;
+import java.util.List;
+
+/**
+ * Trains an {@link AdaptiveLogisticRegression} model on already-encoded ASF email vectors
+ * and writes the best model to {@code <output dir>/asf.model}.
+ * <p/>
+ * Arguments: {@code <input dir> <output dir> <number of categories> <cardinality> [leakType]}.
+ * The input is read as sequence files holding a {@link Text} label and a
+ * {@link VectorWritable} vector per record.
+ */
+public final class TrainASFEmail {
+
+  private static final String USAGE =
+      "Usage: TrainASFEmail <input dir> <output dir> <number of categories> <cardinality> [leakType]";
+
+  // NOTE(review): nothing in this class adds to overallCounts, so the "Word counts"
+  // report at the end of main() only prints its header. It is kept for the pending
+  // dissection TODO below.
+  private static Multiset<String> overallCounts;
+
+  private TrainASFEmail() {
+  }
+
+  /**
+   * Builds the label dictionary in a first pass over the input, trains on a second pass while
+   * reporting progress through {@link SGDHelper#analyzeState}, then serializes the best model.
+   *
+   * @param args see the class description
+   * @throws IOException              if the output directory cannot be created, or propagated
+   *                                  from reading the input or writing the model
+   * @throws IllegalArgumentException if fewer than four arguments are supplied
+   * @throws IllegalStateException    if no model is available after training
+   */
+  public static void main(String[] args) throws IOException {
+    if (args.length < 4) {
+      // fail with an explanation instead of a bare ArrayIndexOutOfBoundsException
+      throw new IllegalArgumentException(USAGE);
+    }
+    File base = new File(args[0]);
+
+    overallCounts = HashMultiset.create();
+    File output = new File(args[1]);
+    if (!output.isDirectory() && !output.mkdirs()) {
+      throw new IOException("Unable to create output directory " + output);
+    }
+    int numCats = Integer.parseInt(args[2]);
+    int cardinality = Integer.parseInt(args[3]);
+
+    int leakType = 0;
+    if (args.length > 4) {
+      leakType = Integer.parseInt(args[4]);
+    }
+
+    Dictionary asfDictionary = new Dictionary();
+
+    AdaptiveLogisticRegression learningAlgorithm =
+        new AdaptiveLogisticRegression(numCats, cardinality, new L1());
+    learningAlgorithm.setInterval(800);
+    learningAlgorithm.setAveragingWindow(500);
+
+    //We ran seq2encoded and split input already, so let's just build up the dictionary
+    Configuration conf = new Configuration();
+    SequenceFileDirIterator<Text, VectorWritable> iter = openInput(base, conf);
+    long numItems = 0;
+    while (iter.hasNext()) {
+      Pair<Text, VectorWritable> next = iter.next();
+      asfDictionary.intern(next.getFirst().toString());
+      numItems++;
+    }
+
+    System.out.printf("%d training files\n", numItems);
+
+    int k = 0;
+    SGDInfo info = new SGDInfo();
+
+    // second pass: the actual training
+    iter = openInput(base, conf);
+    while (iter.hasNext()) {
+      Pair<Text, VectorWritable> next = iter.next();
+      String ng = next.getFirst().toString();
+      int actual = asfDictionary.intern(ng);
+      //we already have encoded
+      learningAlgorithm.train(actual, next.getSecond().get());
+      k++;
+      State<AdaptiveLogisticRegression.Wrapper, CrossFoldLearner> current = learningAlgorithm.getBest();
+
+      SGDHelper.analyzeState(info, leakType, k, current);
+    }
+    learningAlgorithm.close();
+    //TODO: how to dissection since we aren't processing the files here
+    //SGDHelper.dissect(leakType, asfDictionary, learningAlgorithm, files, overallCounts);
+    System.out.println("exiting main, writing model to " + output);
+
+    State<AdaptiveLogisticRegression.Wrapper, CrossFoldLearner> best = learningAlgorithm.getBest();
+    if (best == null) {
+      throw new IllegalStateException("No model available after training on " + base
+          + " (" + numItems + " items read)");
+    }
+    ModelSerializer.writeBinary(output + "/asf.model",
+            best.getPayload().getLearner().getModels().get(0));
+
+    List<Integer> counts = Lists.newArrayList();
+    System.out.printf("Word counts\n");
+    for (String count : overallCounts.elementSet()) {
+      counts.add(overallCounts.count(count));
+    }
+    Collections.sort(counts, Ordering.natural().reverse());
+    k = 0;
+    for (Integer count : counts) {
+      System.out.printf("%d\t%d\n", k, count);
+      k++;
+      if (k > 1000) {
+        break;
+      }
+    }
+  }
+
+  /** Opens an iterator over the part files found under {@code base}. */
+  private static SequenceFileDirIterator<Text, VectorWritable> openInput(File base, Configuration conf)
+    throws IOException {
+    return new SequenceFileDirIterator<Text, VectorWritable>(new Path(base.toString()),
+            PathType.LIST, PathFilters.partFilter(), null, true, conf);
+  }
+}

Modified: 
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainNewsGroups.java
URL: 
http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainNewsGroups.java?rev=1201223&r1=1201222&r2=1201223&view=diff
==============================================================================
--- 
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainNewsGroups.java
 (original)
+++ 
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainNewsGroups.java
 Sat Nov 12 08:19:18 2011
@@ -19,15 +19,10 @@ package org.apache.mahout.classifier.sgd
 
 import com.google.common.collect.HashMultiset;
 import com.google.common.collect.Lists;
-import com.google.common.collect.Maps;
 import com.google.common.collect.Multiset;
 import com.google.common.collect.Ordering;
-
 import org.apache.mahout.ep.State;
-import org.apache.mahout.math.Matrix;
 import org.apache.mahout.math.Vector;
-import org.apache.mahout.math.function.Functions;
-import org.apache.mahout.math.function.DoubleFunction;
 import org.apache.mahout.vectorizer.encoders.Dictionary;
 
 import java.io.File;
@@ -35,9 +30,6 @@ import java.io.IOException;
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.List;
-import java.util.Map;
-import java.util.Random;
-import java.util.Set;
 
 /**
  * Reads and trains an adaptive logistic regression model on the 20 newsgroups 
data.
@@ -45,10 +37,10 @@ import java.util.Set;
  * data.  The optional second argument, leakType, defines which classes of 
features to use.
  * Importantly, leakType controls whether a synthetic date is injected into 
the data as
  * a target leak and if so, how.
- * <p>
+ * <p/>
  * The value of leakType % 3 determines whether the target leak is injected 
according to
  * the following table:
- * <p>
+ * <p/>
  * <table>
  * <tr><td valign='top'>0</td><td>No leak injected</td></tr>
  * <tr><td valign='top'>1</td><td>Synthetic date injected in MMM-yyyy format. 
This will be a single token and
@@ -57,16 +49,16 @@ import java.util.Set;
  * and thus there are more leak symbols that need to be learned.  Ultimately 
this is just
  * as big a leak as case 1.</td></tr>
  * </table>
- * <p>
+ * <p/>
  * Leaktype also determines what other text will be indexed.  If leakType is 
greater
  * than or equal to 6, then neither headers nor text body will be used for 
features and the leak is the only
  * source of data.  If leakType is greater than or equal to 3, then subject 
words will be used as features.
  * If leakType is less than 3, then both subject and body text will be used as 
features.
- * <p>
+ * <p/>
  * A leakType of 0 gives no leak and all textual features.
- * <p>
+ * <p/>
  * See the following table for a summary of commonly used values for leakType
- * <p>
+ * <p/>
  * <table>
  * 
<tr><td><b>leakType</b></td><td><b>Leak?</b></td><td><b>Subject?</b></td><td><b>Body?</b></td></tr>
  * <tr><td colspan=4><hr></td></tr>
@@ -86,8 +78,6 @@ import java.util.Set;
  */
 public final class TrainNewsGroups {
 
-  private static final String[] LEAK_LABELS = {"none", "month-year", 
"day-month-year"};
-
   private static Multiset<String> overallCounts;
 
   private TrainNewsGroups() {
@@ -120,13 +110,11 @@ public final class TrainNewsGroups {
     }
     Collections.shuffle(files);
     System.out.printf("%d training files\n", files.size());
-
-    double averageLL = 0;
-    double averageCorrect = 0;
+    SGDInfo info = new SGDInfo();
 
     int k = 0;
-    double step = 0;
-    int[] bumps = {1, 2, 5};
+
+
     for (File file : files) {
       String ng = file.getParentFile().getName();
       int actual = newsGroups.intern(ng);
@@ -135,69 +123,16 @@ public final class TrainNewsGroups {
       learningAlgorithm.train(actual, v);
 
       k++;
-
-      int bump = bumps[(int) Math.floor(step) % bumps.length];
-      int scale = (int) Math.pow(10, Math.floor(step / bumps.length));
       State<AdaptiveLogisticRegression.Wrapper, CrossFoldLearner> best = 
learningAlgorithm.getBest();
-      double maxBeta;
-      double nonZeros;
-      double positive;
-      double norm;
-
-      double lambda = 0;
-      double mu = 0;
-
-      if (best != null) {
-        CrossFoldLearner state = best.getPayload().getLearner();
-        averageCorrect = state.percentCorrect();
-        averageLL = state.logLikelihood();
-
-        OnlineLogisticRegression model = state.getModels().get(0);
-        // finish off pending regularization
-        model.close();
-        
-        Matrix beta = model.getBeta();
-        maxBeta = beta.aggregate(Functions.MAX, Functions.ABS);
-        nonZeros = beta.aggregate(Functions.PLUS, new DoubleFunction() {
-          @Override
-          public double apply(double v) {
-            return Math.abs(v) > 1.0e-6 ? 1 : 0;
-          }
-        });
-        positive = beta.aggregate(Functions.PLUS, new DoubleFunction() {
-          @Override
-          public double apply(double v) {
-            return v > 0 ? 1 : 0;
-          }
-        });
-        norm = beta.aggregate(Functions.PLUS, Functions.ABS);
-
-        lambda = learningAlgorithm.getBest().getMappedParams()[0];
-        mu = learningAlgorithm.getBest().getMappedParams()[1];
-      } else {
-        maxBeta = 0;
-        nonZeros = 0;
-        positive = 0;
-        norm = 0;
-      }
-      if (k % (bump * scale) == 0) {
-        if (learningAlgorithm.getBest() != null) {
-          ModelSerializer.writeBinary("/tmp/news-group-" + k + ".model",
-                                      
learningAlgorithm.getBest().getPayload().getLearner().getModels().get(0));
-        }
-
-        step += 0.25;
-        System.out.printf("%.2f\t%.2f\t%.2f\t%.2f\t%.8g\t%.8g\t", maxBeta, 
nonZeros, positive, norm, lambda, mu);
-        System.out.printf("%d\t%.3f\t%.2f\t%s\n",
-          k, averageLL, averageCorrect * 100, LEAK_LABELS[leakType % 3]);
-      }
+
+      SGDHelper.analyzeState(info, leakType, k, best);
     }
     learningAlgorithm.close();
-    dissect(leakType, newsGroups, learningAlgorithm, files);
+    SGDHelper.dissect(leakType, newsGroups, learningAlgorithm, files, 
overallCounts);
     System.out.println("exiting main");
 
     ModelSerializer.writeBinary("/tmp/news-group.model",
-                                
learningAlgorithm.getBest().getPayload().getLearner().getModels().get(0));
+            
learningAlgorithm.getBest().getPayload().getLearner().getModels().get(0));
 
     List<Integer> counts = Lists.newArrayList();
     System.out.printf("Word counts\n");
@@ -215,52 +150,5 @@ public final class TrainNewsGroups {
     }
   }
 
-  private static void dissect(int leakType,
-                              Dictionary newsGroups,
-                              AdaptiveLogisticRegression learningAlgorithm,
-                              Iterable<File> files) throws IOException {
-    CrossFoldLearner model = 
learningAlgorithm.getBest().getPayload().getLearner();
-    model.close();
-
-    Map<String, Set<Integer>> traceDictionary = Maps.newTreeMap();
-    ModelDissector md = new ModelDissector();
-
-    NewsgroupHelper helper = new NewsgroupHelper();
-    helper.getEncoder().setTraceDictionary(traceDictionary);
-    helper.getBias().setTraceDictionary(traceDictionary);
-
-    for (File file : permute(files, helper.getRandom()).subList(0, 500)) {
-      String ng = file.getParentFile().getName();
-      int actual = newsGroups.intern(ng);
-
-      traceDictionary.clear();
-      Vector v = helper.encodeFeatureVector(file, actual, leakType, 
overallCounts);
-      md.update(v, traceDictionary, model);
-    }
-
-    List<String> ngNames = Lists.newArrayList(newsGroups.values());
-    List<ModelDissector.Weight> weights = md.summary(100);
-    System.out.println("============");
-    System.out.println("Model Dissection");
-    for (ModelDissector.Weight w : weights) {
-      System.out.printf("%s\t%.1f\t%s\t%.1f\t%s\t%.1f\t%s\n",
-                        w.getFeature(), w.getWeight(), 
ngNames.get(w.getMaxImpact() + 1),
-                        w.getCategory(1), w.getWeight(1), w.getCategory(2), 
w.getWeight(2));
-    }
-  }
-
-  private static List<File> permute(Iterable<File> files, Random rand) {
-    List<File> r = Lists.newArrayList();
-    for (File file : files) {
-      int i = rand.nextInt(r.size() + 1);
-      if (i == r.size()) {
-        r.add(file);
-      } else {
-        r.add(r.get(i));
-        r.set(i, file);
-      }
-    }
-    return r;
-  }
 
 }


Reply via email to