Author: srowen
Date: Sat Jun 4 19:37:27 2011
New Revision: 1131476
URL: http://svn.apache.org/viewvc?rev=1131476&view=rev
Log:
MAHOUT-702 add passive-aggressive learner
Added:
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/PassiveAggressive.java
mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/OnlineBaseTest.java
mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/PassiveAggressiveTest.java
Modified:
mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegressionTest.java
Added:
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/PassiveAggressive.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/PassiveAggressive.java?rev=1131476&view=auto
==============================================================================
---
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/PassiveAggressive.java
(added)
+++
mahout/trunk/core/src/main/java/org/apache/mahout/classifier/sgd/PassiveAggressive.java
Sat Jun 4 19:37:27 2011
@@ -0,0 +1,202 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.classifier.AbstractVectorClassifier;
+import org.apache.mahout.classifier.OnlineLearner;
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.MatrixWritable;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.function.Functions;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+
+/**
+ * Online passive-aggressive learner that tries to minimize the label ranking hinge loss.
+ * Implements a multi-class linear classifier minimizing rank loss, based on
+ * "Online Passive-Aggressive Algorithms" by Crammer et al., 2006.
+ * <p>
+ * Note: It's better to use {@link #classifyNoLink(Vector)} because the loss function is based
+ * on ensuring that the score of the good label is larger than the next
+ * highest label by some margin. The conversion to probability is just done
+ * by exponentiating and dividing by the sum and is empirical at best.
+ * <p>
+ * Your features should be pre-normalized in some sensible range, for example,
+ * by subtracting the mean and dividing by the standard deviation, if they are very
+ * different in magnitude from each other.
+ */
+public class PassiveAggressive extends AbstractVectorClassifier implements OnlineLearner, Writable {
+
+  private static final Logger log = LoggerFactory.getLogger(PassiveAggressive.class);
+
+  public static final int WRITABLE_VERSION = 1;
+
+  // the learning rate of the algorithm; larger values give more aggressive updates
+  private double learningRate = 0.1;
+
+  // loss statistics, reset each time the running average is logged
+  private int lossCount = 0;
+  private double lossSum = 0;
+
+  // coefficients for the classification. This is a dense matrix
+  // that is numCategories x numFeatures
+  private Matrix weights;
+
+  // number of categories we are classifying.
+  private int numCategories;
+
+  /**
+   * Creates a learner whose weights are all zero.
+   *
+   * @param numCategories Number of target categories (rows of the weight matrix).
+   * @param numFeatures   Number of input features (columns of the weight matrix).
+   */
+  public PassiveAggressive(int numCategories, int numFeatures) {
+    this.numCategories = numCategories;
+    weights = new DenseMatrix(numCategories, numFeatures);
+    weights.assign(0.0);
+  }
+
+  /**
+   * Chainable configuration option.
+   *
+   * @param learningRate New value of initial learning rate.
+   * @return This, so other configurations can be chained.
+   */
+  public PassiveAggressive learningRate(double learningRate) {
+    this.learningRate = learningRate;
+    return this;
+  }
+
+  /**
+   * Copies the model state (learning rate, category count and weights) from another learner.
+   * The weights are cloned so that training either learner afterwards does not silently
+   * change the other one.
+   *
+   * @param other The learner to copy from.
+   */
+  public void copyFrom(PassiveAggressive other) {
+    learningRate = other.learningRate;
+    numCategories = other.numCategories;
+    weights = other.weights.clone();
+  }
+
+  @Override
+  public int numCategories() {
+    return numCategories;
+  }
+
+  /**
+   * Scores the instance and converts the scores to pseudo-probabilities by exponentiating and
+   * normalizing. The entry for category 0 is dropped from the returned vector.
+   *
+   * @param instance The feature vector to classify.
+   * @return A vector of numCategories - 1 values for categories 1 and up.
+   */
+  @Override
+  public Vector classify(Vector instance) {
+    Vector result = classifyNoLink(instance);
+    // Convert to probabilities by exponentiation. Subtracting the maximum first keeps
+    // exp() from overflowing and does not change the normalized result.
+    double max = result.maxValue();
+    result.assign(Functions.minus(max)).assign(Functions.EXP);
+    result = result.divide(result.norm(1));
+
+    return result.viewPart(1, result.size() - 1);
+  }
+
+  /**
+   * Computes the raw linear score of the instance for every category.
+   *
+   * @param instance The feature vector to score.
+   * @return A dense vector with one score per row of the weight matrix.
+   */
+  @Override
+  public Vector classifyNoLink(Vector instance) {
+    int rows = weights.numRows();
+    Vector result = new DenseVector(rows);
+    for (int i = 0; i < rows; i++) {
+      result.setQuick(i, weights.viewRow(i).dot(instance));
+    }
+    return result;
+  }
+
+  /**
+   * Two-category convenience form of {@link #classify(Vector)}: returns the pseudo-probability
+   * of category 1 computed from the scores of categories 0 and 1 only.
+   *
+   * @param instance The feature vector to classify.
+   * @return exp(s1) / (exp(s0) + exp(s1)), evaluated without overflow.
+   */
+  @Override
+  public double classifyScalar(Vector instance) {
+    double v1 = weights.viewRow(0).dot(instance);
+    double v2 = weights.viewRow(1).dot(instance);
+    // exp(v2) / (exp(v1) + exp(v2)) == 1 / (1 + exp(v1 - v2)); the second form cannot
+    // produce Infinity / Infinity = NaN when the raw scores are large.
+    return 1.0 / (1.0 + Math.exp(v1 - v2));
+  }
+
+  public int numFeatures() {
+    return weights.numCols();
+  }
+
+  /**
+   * Creates an independent copy of this learner.
+   *
+   * @return A new learner with the same learning rate, category count and weight values.
+   */
+  public PassiveAggressive copy() {
+    close();
+    PassiveAggressive r = new PassiveAggressive(numCategories(), numFeatures());
+    r.copyFrom(this);
+    return r;
+  }
+
+  /**
+   * Serializes the format version, learning rate, category count and weight matrix, in that order.
+   */
+  @Override
+  public void write(DataOutput out) throws IOException {
+    out.writeInt(WRITABLE_VERSION);
+    out.writeDouble(learningRate);
+    out.writeInt(numCategories);
+    MatrixWritable.writeMatrix(out, weights);
+  }
+
+  /**
+   * Restores the state written by {@link #write(DataOutput)}.
+   *
+   * @throws IOException If the stream cannot be read or carries an unexpected format version.
+   */
+  @Override
+  public void readFields(DataInput in) throws IOException {
+    int version = in.readInt();
+    if (version != WRITABLE_VERSION) {
+      throw new IOException("Incorrect object version, wanted " + WRITABLE_VERSION + " got " + version);
+    }
+    learningRate = in.readDouble();
+    numCategories = in.readInt();
+    weights = MatrixWritable.readMatrix(in);
+  }
+
+  @Override
+  public void close() {
+    // This is an online classifier, nothing to do.
+  }
+
+  /**
+   * Performs one passive-aggressive update. If the score of the actual category does not beat
+   * the best competing category by a margin of 1, the instance (scaled by a step size) is added
+   * to the weights of the actual category and subtracted from those of the competitor.
+   *
+   * @param trackingKey Not used by this learner.
+   * @param groupKey    Not used by this learner.
+   * @param actual      The index of the correct category.
+   * @param instance    The feature vector of the training example.
+   */
+  @Override
+  public void train(long trackingKey, String groupKey, int actual, Vector instance) {
+    if (lossCount > 1000) {
+      log.info("Avg. Loss = {}", lossSum / lossCount);
+      lossCount = 0;
+      lossSum = 0;
+    }
+    Vector result = classifyNoLink(instance);
+    double actualScore = result.get(actual);
+    // Find the highest score that is not actual.
+    int otherIndex = result.maxValueIndex();
+    double otherScore = result.get(otherIndex);
+    if (otherIndex == actual) {
+      // the correct label currently wins; mask it out to find the runner-up
+      result.setQuick(otherIndex, Double.NEGATIVE_INFINITY);
+      otherIndex = result.maxValueIndex();
+      otherScore = result.get(otherIndex);
+    }
+    double loss = 1.0 - actualScore + otherScore;
+    lossCount += 1;
+    if (loss > 0) {
+      lossSum += loss;
+      // Step size: the loss damped by the squared norm of the instance plus a term that
+      // shrinks as the learning rate grows.
+      double tau = loss / (instance.dot(instance) + 0.5 / learningRate);
+      Vector delta = instance.clone();
+      delta.assign(Functions.mult(tau));
+      // viewRow returns a view backed by the matrix, so the additions land in the weights
+      delta.addTo(weights.viewRow(actual));
+      delta.assign(Functions.mult(-1));
+      delta.addTo(weights.viewRow(otherIndex));
+    }
+  }
+
+  @Override
+  public void train(long trackingKey, int actual, Vector instance) {
+    train(trackingKey, null, actual, instance);
+  }
+
+  @Override
+  public void train(int actual, Vector instance) {
+    train(0, null, actual, instance);
+  }
+
+}
Added:
mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/OnlineBaseTest.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/OnlineBaseTest.java?rev=1131476&view=auto
==============================================================================
---
mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/OnlineBaseTest.java
(added)
+++
mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/OnlineBaseTest.java
Sat Jun 4 19:37:27 2011
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import com.google.common.base.CharMatcher;
+import com.google.common.base.Charsets;
+import com.google.common.base.Splitter;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Maps;
+import com.google.common.io.CharStreams;
+import com.google.common.io.Resources;
+import org.apache.mahout.classifier.AbstractVectorClassifier;
+import org.apache.mahout.classifier.OnlineLearner;
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.math.DenseMatrix;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Matrix;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.function.Functions;
+
+import java.io.IOException;
+import java.io.InputStreamReader;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+
+/**
+ * Shared fixture and helpers for tests of online learners: loads the standard two-class
+ * test data, trains a learner on it in random order and checks the resulting accuracy.
+ */
+public abstract class OnlineBaseTest extends MahoutTestCase {
+
+  // feature matrix loaded by readStandardData()
+  private Matrix input;
+
+  protected Matrix getInput() {
+    return input;
+  }
+
+  /**
+   * Loads the standard test matrix into this fixture and builds the matching target vector.
+   *
+   * @return The target vector: 0 for the first 30 rows, 1 for the last 30.
+   * @throws IOException If the test data cannot be read.
+   */
+  protected Vector readStandardData() throws IOException {
+    // 60 test samples. First column is constant. Second and third are normally distributed from
+    // either N([2,2], 1) (rows 0...29) or N([-2,-2], 1) (rows 30...59). The first 30 rows have a
+    // target variable of 0, the last 30 a target of 1. The remaining columns are random noise.
+    input = readCsv("sgd.csv");
+
+    // regenerate the target variable
+    Vector target = new DenseVector(60);
+    target.assign(0);
+    target.viewPart(30, 30).assign(1);
+    return target;
+  }
+
+  /**
+   * Trains the learner with a single pass over all samples in a reproducible random order.
+   *
+   * @param input  One sample per row.
+   * @param target The category of each row of input.
+   * @param lr     The learner to train; it is closed when training is done.
+   */
+  protected static void train(Matrix input, Vector target, OnlineLearner lr) {
+    RandomUtils.useTestSeed();
+    Random gen = RandomUtils.getRandom();
+
+    // train on samples in random order (but only one pass)
+    for (int row : permute(gen, target.size())) {
+      lr.train((int) target.get(row), input.getRow(row));
+    }
+    lr.close();
+  }
+
+  /**
+   * Checks the accuracy of a trained classifier and that its convenience methods agree.
+   *
+   * @param input                   One sample per row.
+   * @param target                  The category of each row of input.
+   * @param lr                      The trained classifier.
+   * @param expected_mean_error     Upper bound on the mean absolute error.
+   * @param expected_absolute_error Upper bound on the maximum absolute error.
+   */
+  protected static void test(Matrix input, Vector target, AbstractVectorClassifier lr,
+                             double expected_mean_error, double expected_absolute_error) {
+    // now test the accuracy
+    Matrix tmp = lr.classify(input);
+    Vector errors = tmp.getColumn(0).minus(target);
+
+    // mean(abs(tmp - target))
+    double meanAbsoluteError = errors.aggregate(Functions.PLUS, Functions.ABS) / target.size();
+
+    // max(abs(tmp - target))
+    double maxAbsoluteError = errors.aggregate(Functions.MAX, Functions.ABS);
+
+    System.out.printf("mAE = %.4f, maxAE = %.4f\n", meanAbsoluteError, maxAbsoluteError);
+    assertEquals(0, meanAbsoluteError, expected_mean_error);
+    assertEquals(0, maxAbsoluteError, expected_absolute_error);
+
+    // convenience methods should give the same results
+    Vector v = lr.classifyScalar(input);
+    assertEquals(0, v.minus(tmp.getColumn(0)).norm(1), 1.0e-5);
+    v = lr.classifyFull(input).getColumn(1);
+    assertEquals(0, v.minus(tmp.getColumn(0)).norm(1), 1.0e-4);
+  }
+
+  /**
+   * Permute the integers from 0 ... max-1
+   *
+   * @param gen The random number generator to use.
+   * @param max The number of integers to permute
+   * @return An array of jumbled integer values
+   */
+  protected static int[] permute(Random gen, int max) {
+    int[] permutation = new int[max];
+    // inside-out shuffle: element i goes to a random slot n <= i and the previous
+    // occupant of that slot moves to position i (slot 0 starts out as 0 by default)
+    for (int i = 1; i < max; i++) {
+      int n = gen.nextInt(i + 1);
+      if (n == i) {
+        permutation[i] = i;
+      } else {
+        permutation[i] = permutation[n];
+        permutation[n] = i;
+      }
+    }
+    return permutation;
+  }
+
+
+  /**
+   * Reads a file containing CSV data. This isn't implemented quite the way you might like for a
+   * real program, but does the job for reading test data. Most notably, it will only read numbers,
+   * not quoted strings. The first line is taken as the column labels.
+   *
+   * @param resourceName Where to get the data.
+   * @return A matrix of the results.
+   * @throws IOException If there is an error reading the data
+   */
+  protected static Matrix readCsv(String resourceName) throws IOException {
+    Splitter onCommas = Splitter.on(",").trimResults(CharMatcher.anyOf(" \""));
+
+    List<String> data;
+    InputStreamReader isr =
+        new InputStreamReader(Resources.getResource(resourceName).openStream(), Charsets.UTF_8);
+    try {
+      data = CharStreams.readLines(isr);
+    } finally {
+      // release the underlying resource stream even if reading fails
+      isr.close();
+    }
+    if (data.isEmpty()) {
+      throw new IOException("No header line found in " + resourceName);
+    }
+    String first = data.get(0);
+    data = data.subList(1, data.size());
+
+    List<String> values = Lists.newArrayList(onCommas.split(first));
+    Matrix r = new DenseMatrix(data.size(), values.size());
+
+    // bind each header name to its column index
+    int column = 0;
+    Map<String, Integer> labels = Maps.newHashMap();
+    for (String value : values) {
+      labels.put(value, column);
+      column++;
+    }
+    r.setColumnLabelBindings(labels);
+
+    int row = 0;
+    for (String line : data) {
+      column = 0;
+      values = Lists.newArrayList(onCommas.split(line));
+      for (String value : values) {
+        r.set(row, column, Double.parseDouble(value));
+        column++;
+      }
+      row++;
+    }
+
+    return r;
+  }
+}
Modified:
mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegressionTest.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegressionTest.java?rev=1131476&r1=1131475&r2=1131476&view=diff
==============================================================================
---
mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegressionTest.java
(original)
+++
mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegressionTest.java
Sat Jun 4 19:37:27 2011
@@ -17,33 +17,16 @@
package org.apache.mahout.classifier.sgd;
-import com.google.common.base.CharMatcher;
-import com.google.common.base.Charsets;
-import com.google.common.base.Splitter;
-import com.google.common.collect.Lists;
-import com.google.common.collect.Maps;
-import com.google.common.io.CharStreams;
-import com.google.common.io.Resources;
-import org.apache.mahout.classifier.AbstractVectorClassifier;
-import org.apache.mahout.classifier.OnlineLearner;
-import org.apache.mahout.common.MahoutTestCase;
import org.apache.mahout.common.RandomUtils;
-import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.Vector;
-import org.apache.mahout.math.function.Functions;
import org.junit.Test;
import java.io.IOException;
-import java.io.InputStreamReader;
-import java.util.List;
-import java.util.Map;
import java.util.Random;
-public final class OnlineLogisticRegressionTest extends MahoutTestCase {
-
- private Matrix input;
+public final class OnlineLogisticRegressionTest extends OnlineBaseTest {
/**
* The CrossFoldLearner is probably the best learner to use for new
applications.
@@ -58,10 +41,10 @@ public final class OnlineLogisticRegress
.learningRate(50);
- train(input, target, lr);
+ train(getInput(), target, lr);
System.out.printf("%.2f %.5f\n", lr.auc(), lr.logLikelihood());
- test(input, target, lr);
+ test(getInput(), target, lr, 0.05, 0.3);
}
@@ -162,116 +145,8 @@ public final class OnlineLogisticRegress
.lambda(1 * 1.0e-3)
.learningRate(50);
- train(input, target, lr);
- test(input, target, lr);
- }
-
- private Vector readStandardData() throws IOException {
- // 60 test samples. First column is constant. Second and third are
normally distributed from
- // either N([2,2], 1) (rows 0...29) or N([-2,-2], 1) (rows 30...59). The
first 30 rows have a
- // target variable of 0, the last 30 a target of 1. The remaining columns
are are random noise.
- input = readCsv("sgd.csv");
-
- // regenerate the target variable
- Vector target = new DenseVector(60);
- target.assign(0);
- target.viewPart(30, 30).assign(1);
- return target;
- }
-
- private static void train(Matrix input, Vector target, OnlineLearner lr) {
- RandomUtils.useTestSeed();
- Random gen = RandomUtils.getRandom();
-
- // train on samples in random order (but only one pass)
- for (int row : permute(gen, 60)) {
- lr.train((int) target.get(row), input.getRow(row));
- }
- lr.close();
- }
-
- private static void test(Matrix input, Vector target,
AbstractVectorClassifier lr) {
- // now test the accuracy
- Matrix tmp = lr.classify(input);
- // mean(abs(tmp - target))
- double meanAbsoluteError =
tmp.getColumn(0).minus(target).aggregate(Functions.PLUS, Functions.ABS) / 60;
-
- // max(abs(tmp - target)
- double maxAbsoluteError =
tmp.getColumn(0).minus(target).aggregate(Functions.MAX, Functions.ABS);
-
- System.out.printf("mAE = %.4f, maxAE = %.4f\n", meanAbsoluteError,
maxAbsoluteError);
- assertEquals(0, meanAbsoluteError , 0.05);
- assertEquals(0, maxAbsoluteError, 0.3);
-
- // convenience methods should give the same results
- Vector v = lr.classifyScalar(input);
- assertEquals(0, v.minus(tmp.getColumn(0)).norm(1), 1.0e-5);
- v = lr.classifyFull(input).getColumn(1);
- assertEquals(0, v.minus(tmp.getColumn(0)).norm(1), 1.0e-4);
+ train(getInput(), target, lr);
+ test(getInput(), target, lr, 0.05, 0.3);
}
- /**
- * Permute the integers from 0 ... max-1
- *
- * @param gen The random number generator to use.
- * @param max The number of integers to permute
- * @return An array of jumbled integer values
- */
- private static int[] permute(Random gen, int max) {
- int[] permutation = new int[max];
- permutation[0] = 0;
- for (int i = 1; i < max; i++) {
- int n = gen.nextInt(i + 1);
- if (n == i) {
- permutation[i] = i;
- } else {
- permutation[i] = permutation[n];
- permutation[n] = i;
- }
- }
- return permutation;
- }
-
-
- /**
- * Reads a file containing CSV data. This isn't implemented quite the way
you might like for a
- * real program, but does the job for reading test data. Most notably, it
will only read numbers,
- * not quoted strings.
- *
- * @param resourceName Where to get the data.
- * @return A matrix of the results.
- * @throws IOException If there is an error reading the data
- */
- private static Matrix readCsv(String resourceName) throws IOException {
- Splitter onCommas = Splitter.on(",").trimResults(CharMatcher.anyOf(" \""));
-
- Readable isr = new
InputStreamReader(Resources.getResource(resourceName).openStream(),
Charsets.UTF_8);
- List<String> data = CharStreams.readLines(isr);
- String first = data.get(0);
- data = data.subList(1, data.size());
-
- List<String> values = Lists.newArrayList(onCommas.split(first));
- Matrix r = new DenseMatrix(data.size(), values.size());
-
- int column = 0;
- Map<String, Integer> labels = Maps.newHashMap();
- for (String value : values) {
- labels.put(value, column);
- column++;
- }
- r.setColumnLabelBindings(labels);
-
- int row = 0;
- for (String line : data) {
- column = 0;
- values = Lists.newArrayList(onCommas.split(line));
- for (String value : values) {
- r.set(row, column, Double.parseDouble(value));
- column++;
- }
- row++;
- }
-
- return r;
- }
-}
+}
\ No newline at end of file
Added:
mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/PassiveAggressiveTest.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/PassiveAggressiveTest.java?rev=1131476&view=auto
==============================================================================
---
mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/PassiveAggressiveTest.java
(added)
+++
mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/PassiveAggressiveTest.java
Sat Jun 4 19:37:27 2011
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import org.apache.mahout.math.Vector;
+import org.junit.Test;
+
+import java.io.IOException;
+
+public final class PassiveAggressiveTest extends OnlineBaseTest {
+
+ @Test
+ public void testPassiveAggressive() throws IOException {
+ Vector target = readStandardData();
+ PassiveAggressive pa = new PassiveAggressive(2,8).learningRate(0.1);
+ train(getInput(), target, pa);
+ test(getInput(), target, pa, 0.1, 0.3);
+ }
+
+}