Author: tdunning
Date: Mon Jan 17 03:01:26 2011
New Revision: 1059735
URL: http://svn.apache.org/viewvc?rev=1059735&view=rev
Log:
MAHOUT-496 - Added tests for the examples from chapter 13 of Mahout in Action (MiA)
Modified:
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/RunLogistic.java
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainLogistic.java
mahout/trunk/examples/src/test/java/org/apache/mahout/classifier/sgd/TrainLogisticTest.java
Modified:
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/RunLogistic.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/RunLogistic.java?rev=1059735&r1=1059734&r2=1059735&view=diff
==============================================================================
--- mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/RunLogistic.java	(original)
+++ mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/RunLogistic.java	Mon Jan 17 03:01:26 2011
@@ -33,6 +33,7 @@ import org.apache.mahout.classifier.eval
import java.io.BufferedReader;
import java.io.File;
import java.io.IOException;
+import java.io.PrintStream;
public final class RunLogistic {
@@ -41,6 +42,7 @@ public final class RunLogistic {
private static boolean showAuc;
private static boolean showScores;
private static boolean showConfusion;
+ static PrintStream output = System.out;
private RunLogistic() {
}
@@ -62,29 +64,29 @@ public final class RunLogistic {
csv.firstLine(line);
line = in.readLine();
if (showScores) {
- System.out.printf("\"%s\",\"%s\",\"%s\"\n", "target", "model-output",
"log-likelihood");
+ output.printf("\"%s\",\"%s\",\"%s\"\n", "target", "model-output",
"log-likelihood");
}
while (line != null) {
Vector v = new SequentialAccessSparseVector(lmp.getNumFeatures());
int target = csv.processLine(line, v);
double score = lr.classifyScalar(v);
if (showScores) {
- System.out.printf("%d,%.3f,%.6f\n", target, score,
lr.logLikelihood(target, v));
+ output.printf("%d,%.3f,%.6f\n", target, score,
lr.logLikelihood(target, v));
}
collector.add(target, score);
line = in.readLine();
}
if (showAuc) {
- System.out.printf("AUC = %.2f\n", collector.auc());
+ output.printf("AUC = %.2f\n", collector.auc());
}
if (showConfusion) {
Matrix m = collector.confusion();
- System.out.printf("confusion: [[%.1f, %.1f], [%.1f, %.1f]]\n",
- m.get(0, 0), m.get(1, 0), m.get(0, 1), m.get(1, 1));
+ output.printf("confusion: [[%.1f, %.1f], [%.1f, %.1f]]\n",
+ m.get(0, 0), m.get(1, 0), m.get(0, 1), m.get(1, 1));
m = collector.entropy();
- System.out.printf("entropy: [[%.1f, %.1f], [%.1f, %.1f]]\n",
- m.get(0, 0), m.get(1, 0), m.get(0, 1), m.get(1, 1));
+ output.printf("entropy: [[%.1f, %.1f], [%.1f, %.1f]]\n",
+ m.get(0, 0), m.get(1, 0), m.get(0, 1), m.get(1, 1));
}
}
}
Modified:
mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainLogistic.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainLogistic.java?rev=1059735&r1=1059734&r2=1059735&view=diff
==============================================================================
--- mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainLogistic.java	(original)
+++ mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainLogistic.java	Mon Jan 17 03:01:26 2011
@@ -36,6 +36,7 @@ import java.io.FileWriter;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.OutputStreamWriter;
+import java.io.PrintStream;
import java.net.URL;
import java.util.List;
@@ -52,6 +53,7 @@ public final class TrainLogistic {
private static int passes;
private static boolean scores;
private static OnlineLogisticRegression model;
+ static PrintStream output = System.out;
private TrainLogistic() {
}
@@ -87,8 +89,8 @@ public final class TrainLogistic {
}
double p = lr.classifyScalar(input);
if (scores) {
- System.out.printf("%10d %2d %10.2f %2.4f %10.4f %10.4f\n",
- samples, targetValue, lr.currentLearningRate(), p, logP,
logPEstimate);
+ output.printf("%10d %2d %10.2f %2.4f %10.4f %10.4f\n",
+ samples, targetValue, lr.currentLearningRate(), p, logP,
logPEstimate);
}
// now update model
@@ -106,29 +108,29 @@ public final class TrainLogistic {
modelOutput.close();
}
- System.out.printf("%d\n", lmp.getNumFeatures());
- System.out.printf("%s ~ ", lmp.getTargetVariable());
+ output.printf("%d\n", lmp.getNumFeatures());
+ output.printf("%s ~ ", lmp.getTargetVariable());
String sep = "";
for (String v : csv.getPredictors()) {
double weight = predictorWeight(lr, 0, csv, v);
if (weight != 0) {
- System.out.printf("%s%.3f*%s", sep, weight, v);
+ output.printf("%s%.3f*%s", sep, weight, v);
sep = " + ";
}
}
- System.out.printf("\n");
+ output.printf("\n");
model = lr;
for (int row = 0; row < lr.getBeta().numRows(); row++) {
for (String key : csv.getTraceDictionary().keySet()) {
double weight = predictorWeight(lr, row, csv, key);
if (weight != 0) {
- System.out.printf("%20s %.5f\n", key, weight);
+ output.printf("%20s %.5f\n", key, weight);
}
}
for (int column = 0; column < lr.getBeta().numCols(); column++) {
- System.out.printf("%15.9f ", lr.getBeta().get(row, column));
+ output.printf("%15.9f ", lr.getBeta().get(row, column));
}
- System.out.println();
+ output.println();
}
}
}
Modified:
mahout/trunk/examples/src/test/java/org/apache/mahout/classifier/sgd/TrainLogisticTest.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/examples/src/test/java/org/apache/mahout/classifier/sgd/TrainLogisticTest.java?rev=1059735&r1=1059734&r2=1059735&view=diff
==============================================================================
--- mahout/trunk/examples/src/test/java/org/apache/mahout/classifier/sgd/TrainLogisticTest.java	(original)
+++ mahout/trunk/examples/src/test/java/org/apache/mahout/classifier/sgd/TrainLogisticTest.java	Mon Jan 17 03:01:26 2011
@@ -20,6 +20,7 @@ package org.apache.mahout.classifier.sgd
import com.google.common.base.CharMatcher;
import com.google.common.base.Charsets;
import com.google.common.base.Splitter;
+import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Iterables;
import com.google.common.collect.Sets;
import com.google.common.io.Resources;
@@ -29,8 +30,15 @@ import org.apache.mahout.math.DenseVecto
import org.apache.mahout.math.Vector;
import org.junit.Test;
+import java.io.ByteArrayOutputStream;
+import java.io.FileReader;
import java.io.IOException;
+import java.io.PrintStream;
+import java.lang.reflect.Field;
+import java.lang.reflect.InvocationTargetException;
+import java.lang.reflect.Method;
import java.util.List;
+import java.util.Map;
import java.util.Set;
public class TrainLogisticTest extends MahoutTestCase {
@@ -39,17 +47,22 @@ public class TrainLogisticTest extends M
Splitter.on(CharMatcher.BREAKING_WHITESPACE).trimResults().omitEmptyStrings();
@Test
- public void testMain() throws IOException {
+ public void example13_1() throws IOException, NoSuchFieldException,
IllegalAccessException, InvocationTargetException, NoSuchMethodException {
String outputFile = getTestTempFile("model").getAbsolutePath();
- String inputFile = "donut.csv";
- String[] args = Iterables.toArray(ON_WHITE_SPACE.split(
- "--input " +
- inputFile +
- " --output " +
- outputFile +
- " --target color --categories 2 " +
- "--predictors x y --types numeric --features 20 --passes 100 --rate 50
"), String.class);
- TrainLogistic.main(args);
+ // Every option name and every option value must be a separate array element (no padding, no concatenation).
+ String trainOut = runMain(TrainLogistic.class, new String[]{
+ "--input", "donut.csv",
+ "--output", outputFile,
+ "--target", "color", "--categories", "2",
+ "--predictors", "x", "y",
+ "--types", "numeric",
+ "--features", "20",
+ "--passes", "100",
+ "--rate", "50"
+ });
+ assertTrue(trainOut.contains("x -0.7"));
+ assertTrue(trainOut.contains("y -0.4"));
+
LogisticModelParameters lmp = TrainLogistic.getParameters();
assertEquals(1.0e-4, lmp.getLambda(), 1.0e-9);
assertEquals(20, lmp.getNumFeatures());
@@ -59,10 +72,76 @@ public class TrainLogisticTest extends M
assertEquals("[1, 2]",
Sets.newTreeSet(csv.getTargetCategories()).toString());
assertEquals("[Intercept Term, x, y]",
Sets.newTreeSet(csv.getPredictors()).toString());
-
+ // verify model by building dissector
AbstractVectorClassifier model = TrainLogistic.getModel();
+ List<String> data =
Resources.readLines(Resources.getResource("donut.csv"), Charsets.UTF_8);
+ Map<String, Double> expectedValues = ImmutableMap.of("x", -0.7, "y",
-0.43, "Intercept Term", -0.15);
+ verifyModel(lmp, csv, data, model, expectedValues);
+
+ // test saved model; the reader is closed here rather than relying on loadFrom to close it
+ FileReader modelReader = new FileReader(outputFile);
+ LogisticModelParameters lmpOut;
+ try {
+   lmpOut = LogisticModelParameters.loadFrom(modelReader);
+ } finally {
+   modelReader.close();
+ }
+ CsvRecordFactory csvOut = lmpOut.getCsvRecordFactory();
+ csvOut.firstLine(data.get(0));
+ OnlineLogisticRegression lrOut = lmpOut.createRegression();
+ verifyModel(lmpOut, csvOut, data, lrOut, expectedValues);
+
+ String output = runMain(RunLogistic.class, new String[]{"--input",
"donut.csv", "--model", outputFile, "--auc", "--confusion"});
+ assertTrue(output.contains("AUC = 0.57"));
+ assertTrue(output.contains("confusion: [[27.0, 13.0], [0.0, 0.0]]"));
+ }
+
+ @Test
+ public void example13_2() throws InvocationTargetException, IOException, NoSuchMethodException, NoSuchFieldException, IllegalAccessException {
+   // Train on donut.csv using the additional predictors a, b and c alongside x and y.
+   String outputFile = getTestTempFile("model").getAbsolutePath();
+   String trainOut = runMain(TrainLogistic.class, new String[]{
+     "--input", "donut.csv", "--output", outputFile,
+     "--target", "color", "--categories", "2",
+     "--predictors", "x", "y", "a", "b", "c", "--types", "numeric",
+     "--features", "20", "--passes", "100", "--rate", "50"
+   });
+
+   // The captured output is used as the failure message so a failing check shows what was printed.
+   assertTrue(trainOut, trainOut.contains("a 0."));
+   assertTrue(trainOut, trainOut.contains("b -1."));
+   assertTrue(trainOut, trainOut.contains("c -25."));
+
+   // Score the training file itself.
+   String output = runMain(RunLogistic.class, new String[]{"--input", "donut.csv", "--model", outputFile, "--auc", "--confusion"});
+   assertTrue(output, output.contains("AUC = 1.00"));
+
+   // Score donut-test.csv, which was not part of the training input.
+   String heldout = runMain(RunLogistic.class, new String[]{"--input", "donut-test.csv", "--model", outputFile, "--auc", "--confusion"});
+   assertTrue(heldout, heldout.contains("AUC = 0.9"));
+ }
+
+ /**
+  * Runs a class that has a public static void main(String[]) method and captures what it prints.
+  * The class must declare a static field named "output" holding the PrintStream that main writes to.
+  * That field is pointed at an in-memory buffer for the duration of the call. It is then restored
+  * to its previous value, even if main throws.
+  *
+  * @param clazz contains the main method and the static "output" field.
+  * @param args contains the command line arguments
+  * @return everything main wrote to the output field, decoded as UTF-8.
+  * @throws IOException If the UTF-8 encoding is unavailable (cannot happen on a conforming JVM).
+  * @throws NoSuchFieldException If there isn't an output field.
+  * @throws IllegalAccessException If the output field isn't accessible by us.
+  * @throws NoSuchMethodException If there isn't a public main(String[]) method.
+  * @throws InvocationTargetException If the main method throws an exception.
+  */
+ private String runMain(Class<?> clazz, String[] args) throws IOException, NoSuchFieldException, IllegalAccessException, NoSuchMethodException, InvocationTargetException {
+   ByteArrayOutputStream captured = new ByteArrayOutputStream();
+   // Encode with the same charset used to decode below; the two-argument PrintStream
+   // constructor would silently use the platform default charset instead.
+   PrintStream printStream = new PrintStream(captured, true, Charsets.UTF_8.name());
+
+   Field outputField = clazz.getDeclaredField("output");
+   Method main = clazz.getMethod("main", String[].class);
+
+   // Remember the current stream so the class is not left writing into a closed buffer
+   // after this call (the field is static, so the redirection would otherwise leak into later tests).
+   Object previousOutput = outputField.get(null);
+   outputField.set(null, printStream);
+   try {
+     // Wrap args in an Object[] so reflection passes it as the single String[] parameter.
+     main.invoke(null, new Object[]{args});
+   } finally {
+     outputField.set(null, previousOutput);
+     printStream.close();
+   }
+
+   return new String(captured.toByteArray(), Charsets.UTF_8);
+ }
+
+ private void verifyModel(LogisticModelParameters lmp, CsvRecordFactory csv,
List<String> data, AbstractVectorClassifier model, Map<String, Double>
expectedValues) {
ModelDissector md = new ModelDissector();
- List<String> data = Resources.readLines(Resources.getResource(inputFile),
Charsets.UTF_8);
for (String line : data.subList(1, data.size())) {
Vector v = new DenseVector(lmp.getNumFeatures());
csv.getTraceDictionary().clear();
@@ -70,12 +149,13 @@ public class TrainLogisticTest extends M
md.update(v, csv.getTraceDictionary(), model);
}
+ // check right variables are present
List<ModelDissector.Weight> weights = md.summary(10);
- Set<String> expected = Sets.newHashSet("x", "y", "Intercept Term");
+ Set<String> expected = Sets.newHashSet(expectedValues.keySet());
for (ModelDissector.Weight weight : weights) {
assertTrue(expected.remove(weight.getFeature()));
+ assertEquals(expectedValues.get(weight.getFeature()),
weight.getWeight(), 0.1);
}
assertEquals(0, expected.size());
- System.out.printf("%s\n", weights);
}
}