Author: tdunning
Date: Wed Sep 22 04:32:30 2010
New Revision: 999754
URL: http://svn.apache.org/viewvc?rev=999754&view=rev
Log:
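Improved the calls to model dissection (the best learner is looked up once, outside the loop). Minor cleanups. The model is now closed, which finishes any pending regularization, before its coefficients are summarized and before dissection.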
Modified:
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainNewsGroups.java
Modified:
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainNewsGroups.java
URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainNewsGroups.java?rev=999754&r1=999753&r2=999754&view=diff
==============================================================================
--- mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainNewsGroups.java (original)
+++ mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainNewsGroups.java Wed Sep 22 04:32:30 2010
@@ -28,6 +28,7 @@ import org.apache.lucene.analysis.standa
import org.apache.lucene.analysis.tokenattributes.TermAttribute;
import org.apache.lucene.util.Version;
import org.apache.mahout.ep.State;
+import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.RandomAccessSparseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.function.Functions;
@@ -171,20 +172,25 @@ public class TrainNewsGroups {
averageCorrect = state.percentCorrect();
averageLL = state.logLikelihood();
- maxBeta = state.getModels().get(0).getBeta().aggregate(Functions.MAX,
Functions.IDENTITY);
- nonZeros =
state.getModels().get(0).getBeta().aggregate(Functions.PLUS, new
UnaryFunction() {
+ OnlineLogisticRegression model = state.getModels().get(0);
+ // finish off pending regularization
+ model.close();
+
+ Matrix beta = model.getBeta();
+ maxBeta = beta.aggregate(Functions.MAX, Functions.ABS);
+ nonZeros = beta.aggregate(Functions.PLUS, new UnaryFunction() {
@Override
public double apply(double v) {
return Math.abs(v) > 1e-6 ? 1 : 0;
}
});
- positive =
state.getModels().get(0).getBeta().aggregate(Functions.PLUS, new
UnaryFunction() {
+ positive = beta.aggregate(Functions.PLUS, new UnaryFunction() {
@Override
public double apply(double v) {
return v > 0 ? 1 : 0;
}
});
- norm = state.getModels().get(0).getBeta().aggregate(Functions.PLUS,
Functions.ABS);
+ norm = beta.aggregate(Functions.PLUS, Functions.ABS);
lambda = learningAlgorithm.getBest().getMappedParams()[0];
mu = learningAlgorithm.getBest().getMappedParams()[1];
@@ -215,13 +221,16 @@ public class TrainNewsGroups {
encoder.setTraceDictionary(traceDictionary);
bias.setTraceDictionary(traceDictionary);
int k = 0;
+ CrossFoldLearner model =
learningAlgorithm.getBest().getPayload().getLearner();
+ model.close();
+
for (File file : permute(files, rand).subList(0, 500)) {
String ng = file.getParentFile().getName();
int actual = newsGroups.intern(ng);
traceDictionary.clear();
Vector v = encodeFeatureVector(file, actual, leakType);
- md.update(v, traceDictionary,
learningAlgorithm.getBest().getPayload().getLearner());
+ md.update(v, traceDictionary, model);
if (k % 100 == 0) {
System.out.printf("%d\t%d\n", k, traceDictionary.size());
}