Author: tdunning
Date: Wed Sep 22 17:28:20 2010
New Revision: 1000098

URL: http://svn.apache.org/viewvc?rev=1000098&view=rev
Log:
MAHOUT-511 - Fix regression in old-style logistic regression training.

Added:
    mahout/trunk/examples/src/test/java/org/apache/mahout/classifier/sgd/
    mahout/trunk/examples/src/test/java/org/apache/mahout/classifier/sgd/TrainLogisticTest.java
Modified:
    mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainLogistic.java

Modified: mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainLogistic.java
URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainLogistic.java?rev=1000098&r1=1000097&r2=1000098&view=diff
==============================================================================
--- mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainLogistic.java (original)
+++ mahout/trunk/examples/src/main/java/org/apache/mahout/classifier/sgd/TrainLogistic.java Wed Sep 22 17:28:20 2010
@@ -51,6 +51,7 @@ public final class TrainLogistic {
 
   private static int passes;
   private static boolean scores;
+  private static OnlineLogisticRegression model;
 
   private TrainLogistic() {
   }
@@ -115,6 +116,7 @@ public final class TrainLogistic {
         }
       }
       System.out.printf("\n");
+      model = lr;
       for (int row = 0; row < lr.getBeta().numRows(); row++) {
         for (String key : csv.getTraceDictionary().keySet()) {
           double weight = predictorWeight(lr, row, csv, key);
@@ -285,6 +287,14 @@ public final class TrainLogistic {
     return Double.parseDouble((String) cmdLine.getValue(op));
   }
 
+  public static OnlineLogisticRegression getModel() { // accessor for the static `model` field
+    return model; // stays null until the training code in this class runs `model = lr`
+  }
+
+  public static LogisticModelParameters getParameters() { // accessor for the `lmp` parameters object
+    return lmp; // `lmp` is declared outside this hunk; NOTE(review): confirm when it is populated
+  }
+
   public static class InputOpener {
     private InputOpener() {
     }

Added: mahout/trunk/examples/src/test/java/org/apache/mahout/classifier/sgd/TrainLogisticTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/examples/src/test/java/org/apache/mahout/classifier/sgd/TrainLogisticTest.java?rev=1000098&view=auto
==============================================================================
--- mahout/trunk/examples/src/test/java/org/apache/mahout/classifier/sgd/TrainLogisticTest.java (added)
+++ mahout/trunk/examples/src/test/java/org/apache/mahout/classifier/sgd/TrainLogisticTest.java Wed Sep 22 17:28:20 2010
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import com.google.common.base.CharMatcher;
+import com.google.common.base.Charsets;
+import com.google.common.base.Splitter;
+import com.google.common.collect.Iterables;
+import com.google.common.collect.Sets;
+import com.google.common.io.Files;
+import com.google.common.io.Resources;
+import org.apache.mahout.classifier.AbstractVectorClassifier;
+import org.apache.mahout.examples.MahoutTestCase;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+import org.junit.Test;
+
+import java.io.File;
+import java.io.IOException;
+import java.util.List;
+import java.util.Set;
+
+/**
+ * Runs {@link TrainLogistic#main(String[])} with {@code donut.csv} as input and checks the
+ * resulting parameters and model through the static accessors on {@link TrainLogistic}.
+ */
+public class TrainLogisticTest extends MahoutTestCase {
+  // Splits a command line on breaking whitespace, trimming tokens and dropping empty ones.
+  Splitter onWhiteSpace = Splitter.on(CharMatcher.BREAKING_WHITESPACE).trimResults().omitEmptyStrings();
+  /**
+   * Trains with a fixed argument list, then verifies the recorded parameters and that the model
+   * dissection reports exactly the features x, y and Intercept Term.
+   *
+   * @throws IOException propagated from the training run or from reading the input resource
+   */
+  @Test
+  public void testMain() throws IOException {
+    String outputFile = "./model";
+    String inputFile = "donut.csv";
+    String[] args = Iterables.toArray(onWhiteSpace.split(
+      "--input " +
+        inputFile +
+        " --output " +
+        outputFile +
+        " --target color --categories 2 " +
+        "--predictors x y --types numeric --features 20 --passes 100 --rate 50 "), String.class);
+    TrainLogistic.main(args);
+    LogisticModelParameters lmp = TrainLogistic.getParameters();
+    assertEquals(1e-4, lmp.getLambda(), 1e-9);
+    assertEquals(20, lmp.getNumFeatures());
+    assertTrue(lmp.useBias());
+    assertEquals("color", lmp.getTargetVariable());
+    CsvRecordFactory csv = lmp.getCsvRecordFactory();
+    assertEquals("[1, 2]", Sets.newTreeSet(csv.getTargetCategories()).toString());
+    assertEquals("[Intercept Term, x, y]", Sets.newTreeSet(csv.getPredictors()).toString());
+
+    // Feed every line after the first (presumably the CSV header) through the dissector.
+    AbstractVectorClassifier model = TrainLogistic.getModel();
+    ModelDissector md = new ModelDissector(2);
+    List<String> data = Resources.readLines(Resources.getResource(inputFile), Charsets.UTF_8);
+    for (String line : data.subList(1, data.size())) {
+      Vector v = new DenseVector(lmp.getNumFeatures());
+      csv.getTraceDictionary().clear();
+      csv.processLine(line, v);
+      md.update(v, csv.getTraceDictionary(), model);
+    }
+
+    // Each reported weight must name a distinct expected feature, and all three must appear.
+    List<ModelDissector.Weight> weights = md.summary(10);
+    Set<String> expected = Sets.newHashSet("x", "y", "Intercept Term");
+    for (ModelDissector.Weight weight : weights) {
+      assertTrue(expected.remove(weight.getFeature()));
+    }
+    assertEquals(0, expected.size());
+    System.out.printf("%s\n", weights);
+  }
+}


Reply via email to