Author: tdunning
Date: Wed Sep 22 04:32:30 2010
New Revision: 999754

URL: http://svn.apache.org/viewvc?rev=999754&view=rev
Log:
Improved the calls to model dissection and made slight cleanups. The model is now closed before dissection, so that any pending regularization is applied first.

Modified:
    
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainNewsGroups.java

Modified: 
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainNewsGroups.java
URL: 
http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainNewsGroups.java?rev=999754&r1=999753&r2=999754&view=diff
==============================================================================
--- 
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainNewsGroups.java
 (original)
+++ 
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainNewsGroups.java
 Wed Sep 22 04:32:30 2010
@@ -28,6 +28,7 @@ import org.apache.lucene.analysis.standa
 import org.apache.lucene.analysis.tokenattributes.TermAttribute;
 import org.apache.lucene.util.Version;
 import org.apache.mahout.ep.State;
+import org.apache.mahout.math.Matrix;
 import org.apache.mahout.math.RandomAccessSparseVector;
 import org.apache.mahout.math.Vector;
 import org.apache.mahout.math.function.Functions;
@@ -171,20 +172,25 @@ public class TrainNewsGroups {
         averageCorrect = state.percentCorrect();
         averageLL = state.logLikelihood();
 
-        maxBeta = state.getModels().get(0).getBeta().aggregate(Functions.MAX, 
Functions.IDENTITY);
-        nonZeros = 
state.getModels().get(0).getBeta().aggregate(Functions.PLUS, new 
UnaryFunction() {
+        OnlineLogisticRegression model = state.getModels().get(0);
+        // finish off pending regularization
+        model.close();
+        
+        Matrix beta = model.getBeta();
+        maxBeta = beta.aggregate(Functions.MAX, Functions.ABS);
+        nonZeros = beta.aggregate(Functions.PLUS, new UnaryFunction() {
           @Override
           public double apply(double v) {
             return Math.abs(v) > 1e-6 ? 1 : 0;
           }
         });
-        positive = 
state.getModels().get(0).getBeta().aggregate(Functions.PLUS, new 
UnaryFunction() {
+        positive = beta.aggregate(Functions.PLUS, new UnaryFunction() {
           @Override
           public double apply(double v) {
             return v > 0 ? 1 : 0;
           }
         });
-        norm = state.getModels().get(0).getBeta().aggregate(Functions.PLUS, 
Functions.ABS);
+        norm = beta.aggregate(Functions.PLUS, Functions.ABS);
 
         lambda = learningAlgorithm.getBest().getMappedParams()[0];
         mu = learningAlgorithm.getBest().getMappedParams()[1];
@@ -215,13 +221,16 @@ public class TrainNewsGroups {
     encoder.setTraceDictionary(traceDictionary);
     bias.setTraceDictionary(traceDictionary);
     int k = 0;
+    CrossFoldLearner model = 
learningAlgorithm.getBest().getPayload().getLearner();
+    model.close();
+    
     for (File file : permute(files, rand).subList(0, 500)) {
       String ng = file.getParentFile().getName();
       int actual = newsGroups.intern(ng);
 
       traceDictionary.clear();
       Vector v = encodeFeatureVector(file, actual, leakType);
-      md.update(v, traceDictionary, 
learningAlgorithm.getBest().getPayload().getLearner());
+      md.update(v, traceDictionary, model);
       if (k % 100 == 0) {
         System.out.printf("%d\t%d\n", k, traceDictionary.size());
       }


Reply via email to