Author: tdunning
Date: Mon Jan 17 03:01:26 2011
New Revision: 1059735

URL: http://svn.apache.org/viewvc?rev=1059735&view=rev
Log:
MAHOUT-496 - Added tests for the examples from chapter 13 of Mahout in Action (MiA)

Modified:
    mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/RunLogistic.java
    mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainLogistic.java
    mahout/trunk/examples/src/test/java/org/apache/mahout/classifier/sgd/TrainLogisticTest.java

Modified: mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/RunLogistic.java
URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/RunLogistic.java?rev=1059735&r1=1059734&r2=1059735&view=diff
==============================================================================
--- mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/RunLogistic.java	(original)
+++ mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/RunLogistic.java	Mon Jan 17 03:01:26 2011
@@ -33,6 +33,7 @@ import org.apache.mahout.classifier.eval
 import java.io.BufferedReader;
 import java.io.File;
 import java.io.IOException;
+import java.io.PrintStream;
 
 public final class RunLogistic {
 
@@ -41,6 +42,7 @@ public final class RunLogistic {
   private static boolean showAuc;
   private static boolean showScores;
   private static boolean showConfusion;
+  static PrintStream output = System.out;
 
   private RunLogistic() {
   }
@@ -62,29 +64,29 @@ public final class RunLogistic {
       csv.firstLine(line);
       line = in.readLine();
       if (showScores) {
-        System.out.printf("\"%s\",\"%s\",\"%s\"\n", "target", "model-output", 
"log-likelihood");
+        output.printf("\"%s\",\"%s\",\"%s\"\n", "target", "model-output", 
"log-likelihood");
       }
       while (line != null) {
         Vector v = new SequentialAccessSparseVector(lmp.getNumFeatures());
         int target = csv.processLine(line, v);
         double score = lr.classifyScalar(v);
         if (showScores) {
-          System.out.printf("%d,%.3f,%.6f\n", target, score, 
lr.logLikelihood(target, v));
+          output.printf("%d,%.3f,%.6f\n", target, score, 
lr.logLikelihood(target, v));
         }
         collector.add(target, score);
         line = in.readLine();
       }
 
       if (showAuc) {
-        System.out.printf("AUC = %.2f\n", collector.auc());
+        output.printf("AUC = %.2f\n", collector.auc());
       }
       if (showConfusion) {
         Matrix m = collector.confusion();
-        System.out.printf("confusion: [[%.1f, %.1f], [%.1f, %.1f]]\n",
-            m.get(0, 0), m.get(1, 0), m.get(0, 1), m.get(1, 1));
+        output.printf("confusion: [[%.1f, %.1f], [%.1f, %.1f]]\n",
+          m.get(0, 0), m.get(1, 0), m.get(0, 1), m.get(1, 1));
         m = collector.entropy();
-        System.out.printf("entropy: [[%.1f, %.1f], [%.1f, %.1f]]\n",
-            m.get(0, 0), m.get(1, 0), m.get(0, 1), m.get(1, 1));
+        output.printf("entropy: [[%.1f, %.1f], [%.1f, %.1f]]\n",
+          m.get(0, 0), m.get(1, 0), m.get(0, 1), m.get(1, 1));
       }
     }
   }

Modified: mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainLogistic.java
URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainLogistic.java?rev=1059735&r1=1059734&r2=1059735&view=diff
==============================================================================
--- mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainLogistic.java	(original)
+++ mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainLogistic.java	Mon Jan 17 03:01:26 2011
@@ -36,6 +36,7 @@ import java.io.FileWriter;
 import java.io.IOException;
 import java.io.InputStreamReader;
 import java.io.OutputStreamWriter;
+import java.io.PrintStream;
 import java.net.URL;
 import java.util.List;
 
@@ -52,6 +53,7 @@ public final class TrainLogistic {
   private static int passes;
   private static boolean scores;
   private static OnlineLogisticRegression model;
+  static PrintStream output = System.out;
 
   private TrainLogistic() {
   }
@@ -87,8 +89,8 @@ public final class TrainLogistic {
           }
           double p = lr.classifyScalar(input);
           if (scores) {
-            System.out.printf("%10d %2d %10.2f %2.4f %10.4f %10.4f\n",
-                samples, targetValue, lr.currentLearningRate(), p, logP, 
logPEstimate);
+            output.printf("%10d %2d %10.2f %2.4f %10.4f %10.4f\n",
+              samples, targetValue, lr.currentLearningRate(), p, logP, 
logPEstimate);
           }
 
           // now update model
@@ -106,29 +108,29 @@ public final class TrainLogistic {
         modelOutput.close();
       }
       
-      System.out.printf("%d\n", lmp.getNumFeatures());
-      System.out.printf("%s ~ ", lmp.getTargetVariable());
+      output.printf("%d\n", lmp.getNumFeatures());
+      output.printf("%s ~ ", lmp.getTargetVariable());
       String sep = "";
       for (String v : csv.getPredictors()) {
         double weight = predictorWeight(lr, 0, csv, v);
         if (weight != 0) {
-          System.out.printf("%s%.3f*%s", sep, weight, v);
+          output.printf("%s%.3f*%s", sep, weight, v);
           sep = " + ";
         }
       }
-      System.out.printf("\n");
+      output.printf("\n");
       model = lr;
       for (int row = 0; row < lr.getBeta().numRows(); row++) {
         for (String key : csv.getTraceDictionary().keySet()) {
           double weight = predictorWeight(lr, row, csv, key);
           if (weight != 0) {
-            System.out.printf("%20s %.5f\n", key, weight);
+            output.printf("%20s %.5f\n", key, weight);
           }
         }
         for (int column = 0; column < lr.getBeta().numCols(); column++) {
-          System.out.printf("%15.9f ", lr.getBeta().get(row, column));
+          output.printf("%15.9f ", lr.getBeta().get(row, column));
         }
-        System.out.println();
+        output.println();
       }
     }
   }

Modified: mahout/trunk/examples/src/test/java/org/apache/mahout/classifier/sgd/TrainLogisticTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/test/java/org/apache/mahout/classifier/sgd/TrainLogisticTest.java?rev=1059735&r1=1059734&r2=1059735&view=diff
==============================================================================
--- mahout/trunk/examples/src/test/java/org/apache/mahout/classifier/sgd/TrainLogisticTest.java	(original)
+++ mahout/trunk/examples/src/test/java/org/apache/mahout/classifier/sgd/TrainLogisticTest.java	Mon Jan 17 03:01:26 2011
@@ -20,6 +20,7 @@ package org.apache.mahout.classifier.sgd
 import com.google.common.base.CharMatcher;
 import com.google.common.base.Charsets;
 import com.google.common.base.Splitter;
+import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.Iterables;
 import com.google.common.collect.Sets;
 import com.google.common.io.Resources;
@@ -29,8 +30,15 @@ import org.apache.mahout.math.DenseVecto
 import org.apache.mahout.math.Vector;
 import org.junit.Test;
 
+import java.io.ByteArrayOutputStream;
+import java.io.FileReader;
 import java.io.IOException;
+import java.io.PrintStream;
+import java.lang.reflect.Field;
+import java.lang.reflect.InvocationTargetException;
+import java.lang.reflect.Method;
 import java.util.List;
+import java.util.Map;
 import java.util.Set;
 
 public class TrainLogisticTest extends MahoutTestCase {
@@ -39,17 +47,22 @@ public class TrainLogisticTest extends M
       
Splitter.on(CharMatcher.BREAKING_WHITESPACE).trimResults().omitEmptyStrings();
 
   @Test
-  public void testMain() throws IOException {
+  public void example13_1() throws IOException, NoSuchFieldException, 
IllegalAccessException, InvocationTargetException, NoSuchMethodException {
     String outputFile = getTestTempFile("model").getAbsolutePath();
-    String inputFile = "donut.csv";
-    String[] args = Iterables.toArray(ON_WHITE_SPACE.split(
-      "--input " +
-        inputFile +
-        " --output " +
-        outputFile +
-        " --target color --categories 2 " +
-        "--predictors x y --types numeric --features 20 --passes 100 --rate 50 
"), String.class);
-    TrainLogistic.main(args);
+
+    String trainOut = runMain(TrainLogistic.class, new String[]{
+      "--input", "donut.csv",
+      " --output", outputFile,
+      " --target", "color", "--categories", "2" +
+        "--predictors", "x", "y",
+      "--types", "numeric",
+      "--features", "20",
+      "--passes", "100",
+      "--rate", "50"
+    });
+    assertTrue(trainOut.contains("x -0.7"));
+    assertTrue(trainOut.contains("y -0.4"));
+
     LogisticModelParameters lmp = TrainLogistic.getParameters();
     assertEquals(1.0e-4, lmp.getLambda(), 1.0e-9);
     assertEquals(20, lmp.getNumFeatures());
@@ -59,10 +72,76 @@ public class TrainLogisticTest extends M
     assertEquals("[1, 2]", 
Sets.newTreeSet(csv.getTargetCategories()).toString());
     assertEquals("[Intercept Term, x, y]", 
Sets.newTreeSet(csv.getPredictors()).toString());
 
-
+    // verify model by building dissector
     AbstractVectorClassifier model = TrainLogistic.getModel();
+    List<String> data = 
Resources.readLines(Resources.getResource("donut.csv"), Charsets.UTF_8);
+    Map<String, Double> expectedValues = ImmutableMap.of("x", -0.7, "y", 
-0.43, "Intercept Term", -0.15);
+    verifyModel(lmp, csv, data, model, expectedValues);
+
+    // test saved model
+    LogisticModelParameters lmpOut = LogisticModelParameters.loadFrom(new 
FileReader(outputFile));
+    CsvRecordFactory csvOut = lmpOut.getCsvRecordFactory();
+    csvOut.firstLine(data.get(0));
+    OnlineLogisticRegression lrOut = lmpOut.createRegression();
+    verifyModel(lmpOut, csvOut, data, lrOut, expectedValues);
+
+    String output = runMain(RunLogistic.class, new String[]{"--input", 
"donut.csv", "--model", outputFile, "--auc", "--confusion"});
+    assertTrue(output.contains("AUC = 0.57"));
+    assertTrue(output.contains("confusion: [[27.0, 13.0], [0.0, 0.0]]"));
+  }
+
+  @Test
+  public void example13_2() throws InvocationTargetException, IOException, 
NoSuchMethodException, NoSuchFieldException, IllegalAccessException {
+    String outputFile = getTestTempFile("model").getAbsolutePath();
+    String trainOut = runMain(TrainLogistic.class, new String[]{
+      "--input", "donut.csv", "--output", outputFile,
+      "--target", "color", "--categories", "2",
+      "--predictors", "x", "y", "a", "b", "c", "--types", "numeric",
+      "--features", "20", "--passes", "100", "--rate", "50"
+    });
+
+    assertTrue(trainOut.contains("a 0."));
+    assertTrue(trainOut.contains("b -1."));
+    assertTrue(trainOut.contains("c -25."));
+
+    String output = runMain(RunLogistic.class, new String[]{"--input", 
"donut.csv", "--model", outputFile, "--auc", "--confusion"});
+    assertTrue(output.contains("AUC = 1.00"));
+
+    String heldout = runMain(RunLogistic.class, new String[]{"--input", 
"donut-test.csv", "--model", outputFile, "--auc", "--confusion"});
+    assertTrue(heldout.contains("AUC = 0.9"));
+  }
+
+  /**
+   * Runs a class with a public static void main method.  We assume that there 
is an accessible
+   * field named "output" that we can change to redirect output.
+   *
+   *
+   * @param clazz   contains the main method.
+   * @param args    contains the command line arguments
+   * @return The contents to standard out as a string.
+   * @throws IOException                   Not possible, but must be declared.
+   * @throws NoSuchFieldException          If there isn't an output field.
+   * @throws IllegalAccessException        If the output field isn't 
accessible by us.
+   * @throws NoSuchMethodException         If there isn't a main method.
+   * @throws InvocationTargetException     If the main method throws an 
exception.
+   */
+  private String runMain(Class clazz, String[] args) throws IOException, 
NoSuchFieldException, IllegalAccessException, NoSuchMethodException, 
InvocationTargetException {
+    ByteArrayOutputStream trainOutput = new ByteArrayOutputStream();
+    PrintStream printStream = new PrintStream(trainOutput);
+
+    Field outputField = clazz.getDeclaredField("output");
+    Method main = clazz.getMethod("main", args.getClass());
+
+    outputField.set(null, printStream);
+    Object[] argList = {args};
+    main.invoke(null, argList);
+    printStream.close();
+
+    return new String(trainOutput.toByteArray(), Charsets.UTF_8);
+  }
+
+  private void verifyModel(LogisticModelParameters lmp, CsvRecordFactory csv, 
List<String> data, AbstractVectorClassifier model, Map<String, Double> 
expectedValues) {
     ModelDissector md = new ModelDissector();
-    List<String> data = Resources.readLines(Resources.getResource(inputFile), 
Charsets.UTF_8);
     for (String line : data.subList(1, data.size())) {
       Vector v = new DenseVector(lmp.getNumFeatures());
       csv.getTraceDictionary().clear();
@@ -70,12 +149,13 @@ public class TrainLogisticTest extends M
       md.update(v, csv.getTraceDictionary(), model);
     }
 
+    // check right variables are present
     List<ModelDissector.Weight> weights = md.summary(10);
-    Set<String> expected = Sets.newHashSet("x", "y", "Intercept Term");
+    Set<String> expected = Sets.newHashSet(expectedValues.keySet());
     for (ModelDissector.Weight weight : weights) {
       assertTrue(expected.remove(weight.getFeature()));
+      assertEquals(expectedValues.get(weight.getFeature()), 
weight.getWeight(), 0.1);
     }
     assertEquals(0, expected.size());
-    System.out.printf("%s\n", weights);
   }
 }


Reply via email to