Author: srowen
Date: Sat Jun  4 19:37:27 2011
New Revision: 1131476

URL: http://svn.apache.org/viewvc?rev=1131476&view=rev
Log:
MAHOUT-702 add passive-aggressive learner

Added:
    
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/PassiveAggressive.java
    
mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/OnlineBaseTest.java
    
mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/PassiveAggressiveTest.java
Modified:
    
mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegressionTest.java

Added: 
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/PassiveAggressive.java
URL: 
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/PassiveAggressive.java?rev=1131476&view=auto
==============================================================================
--- 
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/PassiveAggressive.java
 (added)
+++ 
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/PassiveAggressive.java
 Sat Jun  4 19:37:27 2011
@@ -0,0 +1,202 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.classifier.AbstractVectorClassifier;
+import org.apache.mahout.classifier.OnlineLearner;
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.MatrixWritable;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.function.Functions;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+
+/**
+ * Online passive-aggressive learner that tries to minimize the label ranking hinge loss.
+ * Implements a multi-class linear classifier minimizing rank loss, based on
+ * "Online Passive-Aggressive Algorithms" by Crammer et al., 2006.
+ * <p>
+ * Note: it is better to use {@link #classifyNoLink(Vector)} because the loss function is based
+ * on ensuring that the score of the good label is larger than the next
+ * highest label by some margin. The conversion to probability is just done
+ * by exponentiating and dividing by the sum and is empirical at best.
+ * <p>
+ * Your features should be pre-normalized into some sensible range, for example
+ * by subtracting the mean and dividing by the standard deviation, if they are very
+ * different in magnitude from each other.
+ */
+public class PassiveAggressive extends AbstractVectorClassifier implements OnlineLearner, Writable {
+
+  private static final Logger log = LoggerFactory.getLogger(PassiveAggressive.class);
+
+  /** Version tag written first by {@link #write(DataOutput)} and checked by {@link #readFields(DataInput)}. */
+  public static final int WRITABLE_VERSION = 1;
+
+  // the learning rate of the algorithm; larger values give larger update steps (see train)
+  private double learningRate = 0.1;
+
+  // loss statistics, reset each time the running average is logged.
+  private int lossCount = 0;
+  private double lossSum = 0;
+
+  // coefficients for the classification.  This is a dense matrix
+  // that is numCategories x numFeatures
+  private Matrix weights;
+
+  // number of categories we are classifying.
+  private int numCategories;
+
+  /**
+   * Creates a learner whose weights are all zero.
+   *
+   * @param numCategories Number of target categories (rows of the weight matrix).
+   * @param numFeatures   Number of input features (columns of the weight matrix).
+   */
+  public PassiveAggressive(int numCategories, int numFeatures) {
+    this.numCategories = numCategories;
+    weights = new DenseMatrix(numCategories, numFeatures);
+    weights.assign(0.0);
+  }
+
+  /**
+   * Chainable configuration option.
+   *
+   * @param learningRate New value of the learning rate.
+   * @return This, so other configurations can be chained.
+   */
+  public PassiveAggressive learningRate(double learningRate) {
+    this.learningRate = learningRate;
+    return this;
+  }
+
+  /**
+   * Copies the learning rate, category count and weights of another learner into this one.
+   * The loss statistics are not copied.
+   *
+   * @param other The learner to copy the model state from.
+   */
+  public void copyFrom(PassiveAggressive other) {
+    learningRate = other.learningRate;
+    numCategories = other.numCategories;
+    // Take a private copy of the weights: sharing the matrix would let training on either
+    // learner silently modify the other one.
+    weights = other.weights.clone();
+  }
+
+  @Override
+  public int numCategories() {
+    return numCategories;
+  }
+
+  /**
+   * Scores the instance and converts the raw scores to pseudo-probabilities by
+   * exponentiating and normalizing so that they sum to one.
+   *
+   * @param instance The feature vector to classify.
+   * @return The normalized scores with the entry for category 0 dropped, so element
+   *         {@code i} of the result corresponds to category {@code i + 1}.
+   */
+  @Override
+  public Vector classify(Vector instance) {
+    Vector result = classifyNoLink(instance);
+    // Shift by the maximum score before exponentiating so the largest exponent is zero.
+    double max = result.maxValue();
+    result.assign(Functions.minus(max)).assign(Functions.EXP);
+    result = result.divide(result.norm(1));
+
+    return result.viewPart(1, result.size() - 1);
+  }
+
+  /**
+   * Computes the raw linear score of the instance for every category.
+   *
+   * @param instance The feature vector to classify.
+   * @return A vector with one entry per category holding the dot product of that
+   *         category's weight row with the instance.
+   */
+  @Override
+  public Vector classifyNoLink(Vector instance) {
+    int rows = weights.numRows();
+    Vector result = new DenseVector(rows);
+    for (int i = 0; i < rows; i++) {
+      result.setQuick(i, weights.viewRow(i).dot(instance));
+    }
+    return result;
+  }
+
+  /**
+   * Binary shortcut: computes the pseudo-probability of category 1 from the scores of
+   * categories 0 and 1 only.
+   *
+   * @param instance The feature vector to classify.
+   * @return {@code exp(s1) / (exp(s0) + exp(s1))} where {@code s0} and {@code s1} are the
+   *         raw scores of categories 0 and 1.
+   */
+  @Override
+  public double classifyScalar(Vector instance) {
+    double score0 = weights.viewRow(0).dot(instance);
+    double score1 = weights.viewRow(1).dot(instance);
+    // exp(s1) / (exp(s0) + exp(s1)) == 1 / (1 + exp(s0 - s1)).  The second form needs only one
+    // exponential and cannot produce Infinity / Infinity = NaN when both scores are large.
+    return 1.0 / (1.0 + Math.exp(score0 - score1));
+  }
+
+  /**
+   * @return The number of features, i.e. the number of columns of the weight matrix.
+   */
+  public int numFeatures() {
+    return weights.numCols();
+  }
+
+  /**
+   * Closes this learner and returns a new learner with the same configuration and weights.
+   *
+   * @return An independent copy of this learner.
+   */
+  public PassiveAggressive copy() {
+    close();
+    PassiveAggressive r = new PassiveAggressive(numCategories(), numFeatures());
+    r.copyFrom(this);
+    return r;
+  }
+
+  /**
+   * Serializes the version tag, learning rate, category count and weight matrix, in that order.
+   *
+   * @param out Where to write the model.
+   * @throws IOException If writing fails.
+   */
+  @Override
+  public void write(DataOutput out) throws IOException {
+    out.writeInt(WRITABLE_VERSION);
+    out.writeDouble(learningRate);
+    out.writeInt(numCategories);
+    MatrixWritable.writeMatrix(out, weights);
+  }
+
+  /**
+   * Restores a model previously written by {@link #write(DataOutput)}.
+   *
+   * @param in Where to read the model from.
+   * @throws IOException If reading fails or the stored version is not {@link #WRITABLE_VERSION}.
+   */
+  @Override
+  public void readFields(DataInput in) throws IOException {
+    int version = in.readInt();
+    if (version != WRITABLE_VERSION) {
+      throw new IOException("Incorrect object version, wanted " + WRITABLE_VERSION + " got " + version);
+    }
+    learningRate = in.readDouble();
+    numCategories = in.readInt();
+    weights = MatrixWritable.readMatrix(in);
+  }
+
+  @Override
+  public void close() {
+    // This is an online classifier, nothing to do.
+  }
+
+  /**
+   * Performs one passive-aggressive update.  The loss is
+   * {@code 1 - score(actual) + score(rival)}, where the rival is the highest-scoring category
+   * other than {@code actual}.  If the loss is non-negative, the weight row of the actual
+   * category is moved towards the instance and the rival's row is moved away from it by the
+   * same amount.
+   *
+   * @param trackingKey Ignored by this learner.
+   * @param groupKey    Ignored by this learner.
+   * @param actual      Index of the correct category for the instance.
+   * @param instance    The feature vector to learn from.
+   */
+  @Override
+  public void train(long trackingKey, String groupKey, int actual, Vector instance) {
+    // Periodically report and reset the running average loss.
+    if (lossCount > 1000) {
+      log.info("Avg. Loss = {}", lossSum / lossCount);
+      lossCount = 0;
+      lossSum = 0;
+    }
+    Vector scores = classifyNoLink(instance);
+    double actualScore = scores.get(actual);
+    // Find the highest score that is not actual.
+    int rivalIndex = scores.maxValueIndex();
+    if (rivalIndex == actual) {
+      // Mask out the actual category so the next maximum is the best rival.
+      scores.setQuick(actual, Double.NEGATIVE_INFINITY);
+      rivalIndex = scores.maxValueIndex();
+    }
+    double rivalScore = scores.get(rivalIndex);
+    double loss = 1.0 - actualScore + rivalScore;
+    lossCount++;
+    if (loss >= 0) {
+      lossSum += loss;
+      // Step size: the loss scaled by the squared norm of the instance plus a term that
+      // shrinks as the learning rate grows.
+      double tau = loss / (instance.dot(instance) + 0.5 / learningRate);
+      Vector delta = instance.clone();
+      delta.assign(Functions.mult(tau));
+      // Update through row views so the change is applied to the weight matrix itself
+      // rather than depending on whether getRow hands back a copy.
+      delta.addTo(weights.viewRow(actual));
+      delta.assign(Functions.mult(-1));
+      delta.addTo(weights.viewRow(rivalIndex));
+    }
+  }
+
+  @Override
+  public void train(long trackingKey, int actual, Vector instance) {
+    train(trackingKey, null, actual, instance);
+  }
+
+  @Override
+  public void train(int actual, Vector instance) {
+    train(0, null, actual, instance);
+  }
+
+}

Added: 
mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/OnlineBaseTest.java
URL: 
http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/OnlineBaseTest.java?rev=1131476&view=auto
==============================================================================
--- 
mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/OnlineBaseTest.java
 (added)
+++ 
mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/OnlineBaseTest.java
 Sat Jun  4 19:37:27 2011
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import com.google.common.base.CharMatcher;
+import com.google.common.base.Charsets;
+import com.google.common.base.Splitter;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Maps;
+import com.google.common.io.CharStreams;
+import com.google.common.io.Resources;
+import org.apache.mahout.classifier.AbstractVectorClassifier;
+import org.apache.mahout.classifier.OnlineLearner;
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.function.Functions;
+
+import java.io.IOException;
+import java.io.InputStreamReader;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+/**
+ * Shared fixture for tests of online learners: loads the standard test data, trains a
+ * learner on it in a reproducible random order and checks the resulting accuracy.
+ */
+public abstract class OnlineBaseTest extends MahoutTestCase {
+
+  // feature matrix loaded by readStandardData(); null until that method has been called
+  private Matrix input;
+
+  /**
+   * @return The feature matrix loaded by the last call to {@link #readStandardData()}.
+   */
+  protected Matrix getInput() {
+    return input;
+  }
+
+  /**
+   * Loads the standard test data into {@link #getInput()} and builds the matching targets.
+   *
+   * @return The target vector: 0 for the first 30 rows, 1 for the last 30.
+   * @throws IOException If the data cannot be read.
+   */
+  protected Vector readStandardData() throws IOException {
+    // 60 test samples.  First column is constant.  Second and third are normally distributed from
+    // either N([2,2], 1) (rows 0...29) or N([-2,-2], 1) (rows 30...59).  The first 30 rows have a
+    // target variable of 0, the last 30 a target of 1.  The remaining columns are random noise.
+    input = readCsv("sgd.csv");
+
+    // regenerate the target variable
+    Vector target = new DenseVector(60);
+    target.assign(0);
+    target.viewPart(30, 30).assign(1);
+    return target;
+  }
+
+  /**
+   * Trains the learner with a single pass over the samples in a seeded random order,
+   * then closes it.
+   *
+   * @param input  Feature matrix, one sample per row.
+   * @param target Target category for each row of {@code input}.
+   * @param lr     The learner to train.
+   */
+  protected static void train(Matrix input, Vector target, OnlineLearner lr) {
+    RandomUtils.useTestSeed();
+    Random gen = RandomUtils.getRandom();
+
+    // train on samples in random order (but only one pass)
+    for (int row : permute(gen, target.size())) {
+      lr.train((int) target.get(row), input.getRow(row));
+    }
+    lr.close();
+  }
+
+  /**
+   * Checks the accuracy of a trained classifier and the consistency of its convenience methods.
+   *
+   * @param input                   Feature matrix, one sample per row.
+   * @param target                  Target value for each row of {@code input}.
+   * @param lr                      The trained classifier.
+   * @param expected_mean_error     Largest acceptable mean absolute error.
+   * @param expected_absolute_error Largest acceptable absolute error on any single sample.
+   */
+  protected static void test(Matrix input, Vector target, AbstractVectorClassifier lr,
+                             double expected_mean_error, double expected_absolute_error) {
+    // now test the accuracy
+    Matrix tmp = lr.classify(input);
+    Vector predicted = tmp.getColumn(0);
+    Vector residual = predicted.minus(target);
+
+    // mean(abs(tmp - target))
+    double meanAbsoluteError = residual.aggregate(Functions.PLUS, Functions.ABS) / target.size();
+
+    // max(abs(tmp - target))
+    double maxAbsoluteError = residual.aggregate(Functions.MAX, Functions.ABS);
+
+    System.out.printf("mAE = %.4f, maxAE = %.4f\n", meanAbsoluteError, maxAbsoluteError);
+    assertEquals(0, meanAbsoluteError, expected_mean_error);
+    assertEquals(0, maxAbsoluteError, expected_absolute_error);
+
+    // convenience methods should give the same results
+    Vector v = lr.classifyScalar(input);
+    assertEquals(0, v.minus(predicted).norm(1), 1.0e-5);
+    v = lr.classifyFull(input).getColumn(1);
+    assertEquals(0, v.minus(predicted).norm(1), 1.0e-4);
+  }
+
+  /**
+   * Permute the integers from 0 ... max-1
+   *
+   * @param gen The random number generator to use.
+   * @param max The number of integers to permute
+   * @return An array of jumbled integer values; empty if {@code max} is 0
+   */
+  protected static int[] permute(Random gen, int max) {
+    // A new int array is zero-filled, so position 0 already holds the value 0.
+    int[] permutation = new int[max];
+    for (int i = 1; i < max; i++) {
+      // Place the new value i at a random position n in 0..i and move whatever was there
+      // to position i.  When n == i the second assignment simply stores i at position i.
+      int n = gen.nextInt(i + 1);
+      permutation[i] = permutation[n];
+      permutation[n] = i;
+    }
+    return permutation;
+  }
+
+
+  /**
+   * Reads a file containing CSV data.  This isn't implemented quite the way you might like for a
+   * real program, but does the job for reading test data.  Most notably, it will only read numbers,
+   * not quoted strings.  The first line is taken as the column labels.
+   *
+   * @param resourceName Where to get the data.
+   * @return A matrix of the results.
+   * @throws IOException If there is an error reading the data
+   */
+  protected static Matrix readCsv(String resourceName) throws IOException {
+    Splitter onCommas = Splitter.on(",").trimResults(CharMatcher.anyOf(" \""));
+
+    List<String> data;
+    InputStreamReader isr =
+        new InputStreamReader(Resources.getResource(resourceName).openStream(), Charsets.UTF_8);
+    try {
+      data = CharStreams.readLines(isr);
+    } finally {
+      // Release the underlying stream whether or not reading succeeded.
+      isr.close();
+    }
+    String first = data.get(0);
+    data = data.subList(1, data.size());
+
+    List<String> values = Lists.newArrayList(onCommas.split(first));
+    Matrix r = new DenseMatrix(data.size(), values.size());
+
+    // bind each header label to its column index
+    int column = 0;
+    Map<String, Integer> labels = Maps.newHashMap();
+    for (String value : values) {
+      labels.put(value, column);
+      column++;
+    }
+    r.setColumnLabelBindings(labels);
+
+    int row = 0;
+    for (String line : data) {
+      column = 0;
+      values = Lists.newArrayList(onCommas.split(line));
+      for (String value : values) {
+        r.set(row, column, Double.parseDouble(value));
+        column++;
+      }
+      row++;
+    }
+
+    return r;
+  }
+}

Modified: 
mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegressionTest.java
URL: 
http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegressionTest.java?rev=1131476&r1=1131475&r2=1131476&view=diff
==============================================================================
--- 
mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegressionTest.java
 (original)
+++ 
mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegressionTest.java
 Sat Jun  4 19:37:27 2011
@@ -17,33 +17,16 @@
 
 package org.apache.mahout.classifier.sgd;
 
-import com.google.common.base.CharMatcher;
-import com.google.common.base.Charsets;
-import com.google.common.base.Splitter;
-import com.google.common.collect.Lists;
-import com.google.common.collect.Maps;
-import com.google.common.io.CharStreams;
-import com.google.common.io.Resources;
-import org.apache.mahout.classifier.AbstractVectorClassifier;
-import org.apache.mahout.classifier.OnlineLearner;
-import org.apache.mahout.common.MahoutTestCase;
 import org.apache.mahout.common.RandomUtils;
-import org.apache.mahout.math.DenseMatrix;
 import org.apache.mahout.math.DenseVector;
 import org.apache.mahout.math.Matrix;
 import org.apache.mahout.math.Vector;
-import org.apache.mahout.math.function.Functions;
 import org.junit.Test;
 
 import java.io.IOException;
-import java.io.InputStreamReader;
-import java.util.List;
-import java.util.Map;
 import java.util.Random;
 
-public final class OnlineLogisticRegressionTest extends MahoutTestCase {
-
-  private Matrix input;
+public final class OnlineLogisticRegressionTest extends OnlineBaseTest {
 
   /**
    * The CrossFoldLearner is probably the best learner to use for new 
applications.
@@ -58,10 +41,10 @@ public final class OnlineLogisticRegress
             .learningRate(50);
 
 
-    train(input, target, lr);
+    train(getInput(), target, lr);
 
     System.out.printf("%.2f %.5f\n", lr.auc(), lr.logLikelihood());
-    test(input, target, lr);
+    test(getInput(), target, lr, 0.05, 0.3);
 
   }
 
@@ -162,116 +145,8 @@ public final class OnlineLogisticRegress
             .lambda(1 * 1.0e-3)
             .learningRate(50);
 
-    train(input, target, lr);
-    test(input, target, lr);
-  }
-
-  private Vector readStandardData() throws IOException {
-    // 60 test samples.  First column is constant.  Second and third are 
normally distributed from
-    // either N([2,2], 1) (rows 0...29) or N([-2,-2], 1) (rows 30...59).  The 
first 30 rows have a
-    // target variable of 0, the last 30 a target of 1.  The remaining columns 
are are random noise.
-    input = readCsv("sgd.csv");
-
-    // regenerate the target variable
-    Vector target = new DenseVector(60);
-    target.assign(0);
-    target.viewPart(30, 30).assign(1);
-    return target;
-  }
-
-  private static void train(Matrix input, Vector target, OnlineLearner lr) {
-    RandomUtils.useTestSeed();
-    Random gen = RandomUtils.getRandom();
-
-    // train on samples in random order (but only one pass)
-    for (int row : permute(gen, 60)) {
-      lr.train((int) target.get(row), input.getRow(row));
-    }
-    lr.close();
-  }
-
-  private static void test(Matrix input, Vector target, 
AbstractVectorClassifier lr) {
-    // now test the accuracy
-    Matrix tmp = lr.classify(input);
-    // mean(abs(tmp - target))
-    double meanAbsoluteError = 
tmp.getColumn(0).minus(target).aggregate(Functions.PLUS, Functions.ABS) / 60;
-
-    // max(abs(tmp - target)
-    double maxAbsoluteError = 
tmp.getColumn(0).minus(target).aggregate(Functions.MAX, Functions.ABS);
-
-    System.out.printf("mAE = %.4f, maxAE = %.4f\n", meanAbsoluteError, 
maxAbsoluteError);
-    assertEquals(0, meanAbsoluteError , 0.05);
-    assertEquals(0, maxAbsoluteError, 0.3);
-
-    // convenience methods should give the same results
-    Vector v = lr.classifyScalar(input);
-    assertEquals(0, v.minus(tmp.getColumn(0)).norm(1), 1.0e-5);
-    v = lr.classifyFull(input).getColumn(1);
-    assertEquals(0, v.minus(tmp.getColumn(0)).norm(1), 1.0e-4);
+    train(getInput(), target, lr);
+    test(getInput(), target, lr, 0.05, 0.3);
   }
 
-  /**
-   * Permute the integers from 0 ... max-1
-   *
-   * @param gen The random number generator to use.
-   * @param max The number of integers to permute
-   * @return An array of jumbled integer values
-   */
-  private static int[] permute(Random gen, int max) {
-    int[] permutation = new int[max];
-    permutation[0] = 0;
-    for (int i = 1; i < max; i++) {
-      int n = gen.nextInt(i + 1);
-      if (n == i) {
-        permutation[i] = i;
-      } else {
-        permutation[i] = permutation[n];
-        permutation[n] = i;
-      }
-    }
-    return permutation;
-  }
-
-
-  /**
-   * Reads a file containing CSV data.  This isn't implemented quite the way 
you might like for a
-   * real program, but does the job for reading test data.  Most notably, it 
will only read numbers,
-   * not quoted strings.
-   *
-   * @param resourceName Where to get the data.
-   * @return A matrix of the results.
-   * @throws IOException If there is an error reading the data
-   */
-  private static Matrix readCsv(String resourceName) throws IOException {
-    Splitter onCommas = Splitter.on(",").trimResults(CharMatcher.anyOf(" \""));
-
-    Readable isr = new 
InputStreamReader(Resources.getResource(resourceName).openStream(), 
Charsets.UTF_8);
-    List<String> data = CharStreams.readLines(isr);
-    String first = data.get(0);
-    data = data.subList(1, data.size());
-
-    List<String> values = Lists.newArrayList(onCommas.split(first));
-    Matrix r = new DenseMatrix(data.size(), values.size());
-
-    int column = 0;
-    Map<String, Integer> labels = Maps.newHashMap();
-    for (String value : values) {
-      labels.put(value, column);
-      column++;
-    }
-    r.setColumnLabelBindings(labels);
-
-    int row = 0;
-    for (String line : data) {
-      column = 0;
-      values = Lists.newArrayList(onCommas.split(line));
-      for (String value : values) {
-        r.set(row, column, Double.parseDouble(value));
-        column++;
-      }
-      row++;
-    }
-
-    return r;
-  }
-}
+}
\ No newline at end of file

Added: 
mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/PassiveAggressiveTest.java
URL: 
http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/PassiveAggressiveTest.java?rev=1131476&view=auto
==============================================================================
--- 
mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/PassiveAggressiveTest.java
 (added)
+++ 
mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/PassiveAggressiveTest.java
 Sat Jun  4 19:37:27 2011
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import org.apache.mahout.math.Vector;
+import org.junit.Test;
+
+import java.io.IOException;
+
+/**
+ * Trains a {@link PassiveAggressive} learner on the standard test data and checks its accuracy.
+ */
+public final class PassiveAggressiveTest extends OnlineBaseTest {
+
+  /**
+   * A two-category, eight-feature learner with learning rate 0.1 must reach a mean absolute
+   * error of at most 0.1 and a maximum absolute error of at most 0.3 after one training pass.
+   */
+  @Test
+  public void testPassiveAggressive() throws IOException {
+    Vector labels = readStandardData();
+
+    PassiveAggressive learner = new PassiveAggressive(2, 8);
+    learner.learningRate(0.1);
+
+    train(getInput(), labels, learner);
+    test(getInput(), labels, learner, 0.1, 0.3);
+  }
+
+}


Reply via email to