Added: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/ContinuousValueEncoderTest.java URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/ContinuousValueEncoderTest.java?rev=986045&view=auto ============================================================================== --- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/ContinuousValueEncoderTest.java (added) +++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/ContinuousValueEncoderTest.java Mon Aug 16 16:56:46 2010 @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.classifier.sgd; + +import org.apache.mahout.math.DenseVector; +import org.apache.mahout.math.Vector; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; + +/** Checks that ContinuousValueEncoder adds a parsed numeric value into a vector and rejects non-numeric text. */ public class ContinuousValueEncoderTest { + @Test + public void testAddToVector() { + FeatureVectorEncoder enc = new ContinuousValueEncoder("foo"); + Vector v1 = new DenseVector(20); + enc.addToVector("-123", v1); + assertEquals(-123, v1.minValue(), 0); + assertEquals(0, v1.maxValue(), 0); + assertEquals(123, v1.norm(1), 0); + + v1 = new DenseVector(20); + enc.addToVector("123", v1); + assertEquals(123, v1.maxValue(), 0); + assertEquals(0, v1.minValue(), 0); + assertEquals(123, v1.norm(1), 0); + + Vector v2 = new DenseVector(20); + enc.setProbes(2); /* the assertions below expect a second probe to store the value in a second slot, doubling the L1 norm */ + enc.addToVector("123", v2); + assertEquals(123, v2.maxValue(), 0); + assertEquals(2 * 123, v2.norm(1), 0); + + v1 = v2.minus(v1); + assertEquals(123, v1.maxValue(), 0); + assertEquals(123, v1.norm(1), 0); + + Vector v3 = new DenseVector(20); + enc.setProbes(2); + enc.addToVector("100", v3); + v1 = v2.minus(v3); + assertEquals(23, v1.maxValue(), 0); + assertEquals(2 * 23, v1.norm(1), 0); + + enc.addToVector("7", v1); + assertEquals(30, v1.maxValue(), 0); + assertEquals(2 * 30, v1.norm(1), 0); + assertEquals(30, v1.get(10), 0); + assertEquals(30, v1.get(18), 0); + + try { /* non-numeric text is expected to be rejected with a NumberFormatException */ + enc.addToVector("foobar", v1); + fail("Should have noticed bad numeric format"); + } catch (NumberFormatException e) { + assertEquals("For input string: \"foobar\"", e.getMessage()); + } + } + + @Test + public void testAsString() { + ContinuousValueEncoder enc = new ContinuousValueEncoder("foo"); + assertEquals("foo:123", enc.asString("123")); + } + +}
Added: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/CsvRecordFactoryTest.java URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/CsvRecordFactoryTest.java?rev=986045&view=auto ============================================================================== --- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/CsvRecordFactoryTest.java (added) +++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/CsvRecordFactoryTest.java Mon Aug 16 16:56:46 2010 @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.classifier.sgd; + +import com.google.common.collect.ImmutableMap; +import org.apache.mahout.math.DenseVector; +import org.apache.mahout.math.Vector; +import org.junit.Assert; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; + +public class CsvRecordFactoryTest { + @Test + public void testAddToVector() { + CsvRecordFactory csv = new CsvRecordFactory("y", ImmutableMap.of("x1", "n", "x2", "w", "x3", "t")); + csv.firstLine("z,x1,y,x2,x3,q"); + csv.maxTargetValue(2); + + Vector v = new DenseVector(2000); + int t = csv.processLine("ignore,3.1,yes,tiger, \"this is text\",ignore", v); + assertEquals(0, t); + // should have 9 values set + assertEquals(9.0, v.norm(0), 0); + // all should be = 1 except for the 3.1 + assertEquals(3.1, v.maxValue(), 0); + v.set(v.maxValueIndex(), 0); + assertEquals(8.0, v.norm(0), 0); + assertEquals(8.0, v.norm(1), 0); + assertEquals(1.0, v.maxValue(), 0); + + v.assign(0); + t = csv.processLine("ignore,5.3,no,line, \"and more text and more\",ignore", v); + assertEquals(1, t); + + // should have 9 values set + assertEquals(9.0, v.norm(0), 0); + // 5.3 is the largest value; the remaining 8 values sum to 12 with a maximum of 2 + assertEquals(5.3, v.maxValue(), 0); + v.set(v.maxValueIndex(), 0); + assertEquals(8.0, v.norm(0), 0); + assertEquals(12.0, v.norm(1), 0); + assertEquals(2, v.maxValue(), 0); + + v.assign(0); + t = csv.processLine("ignore,5.3,invalid,line, \"and more text and more\",ignore", v); + assertEquals(1, t); + + // should have 9 values set + assertEquals(9.0, v.norm(0), 0); + // 5.3 is the largest value; the remaining 8 values sum to 12 with a maximum of 2 + assertEquals(5.3, v.maxValue(), 0); + v.set(v.maxValueIndex(), 0); + assertEquals(8.0, v.norm(0), 0); + assertEquals(12.0, v.norm(1), 0); + assertEquals(2, v.maxValue(), 0); + } + + @Test + public void testDictionaryOrder() { + Dictionary dict = new Dictionary(); + + dict.intern("a"); + dict.intern("d"); + dict.intern("c"); + dict.intern("b"); + dict.intern("qrz"); + + Assert.assertEquals("[a, d, c, b, 
qrz]", dict.values().toString()); + + Dictionary dict2 = Dictionary.fromList(dict.values()); + Assert.assertEquals("[a, d, c, b, qrz]", dict2.values().toString()); + + } +} Added: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegressionTest.java URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegressionTest.java?rev=986045&view=auto ============================================================================== --- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegressionTest.java (added) +++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegressionTest.java Mon Aug 16 16:56:46 2010 @@ -0,0 +1,273 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.classifier.sgd; + +import com.google.common.base.CharMatcher; +import com.google.common.base.Splitter; +import com.google.common.collect.Lists; +import com.google.common.collect.Maps; +import com.google.common.io.CharStreams; +import com.google.common.io.Resources; +import org.apache.mahout.classifier.AbstractVectorClassifier; +import org.apache.mahout.classifier.OnlineLearner; +import org.apache.mahout.common.RandomUtils; +import org.apache.mahout.math.DenseMatrix; +import org.apache.mahout.math.DenseVector; +import org.apache.mahout.math.Matrix; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.function.Functions; +import org.junit.Test; + +import java.io.IOException; +import java.io.InputStreamReader; +import java.util.List; +import java.util.Map; +import java.util.Random; + +import static org.junit.Assert.assertEquals; + +public class OnlineLogisticRegressionTest { + private Matrix input; + + /** + * The CrossFoldLearner is probably the best learner to use for new applications. + * @throws IOException If test resources aren't readable. 
+ */ + @Test + public void crossValidation() throws IOException { + Vector target = readStandardData(); + + CrossFoldLearner lr = new CrossFoldLearner(5, 2, 8, new L1()) + .lambda(1 * 1e-3) + .learningRate(50); + + + train(input, target, lr); + + System.out.printf("%.2f %.5f\n", lr.auc(), lr.logLikelihood()); + test(input, target, lr); + + } + + @Test + public void crossValidatedAuc() throws IOException { + RandomUtils.useTestSeed(); + Random gen = RandomUtils.getRandom(); + + Matrix data = readCsv("cancer.csv"); + CrossFoldLearner lr = new CrossFoldLearner(5, 2, 10, new L1()) + .stepOffset(10) + .decayExponent(0.7) + .lambda(1 * 1e-3) + .learningRate(5); + int k = 0; + int[] ordering = permute(gen, data.numRows()); + for (int epoch = 0; epoch < 100; epoch++) { + for (int row : ordering) { + lr.train(row, (int) data.get(row, 9), data.viewRow(row)); + System.out.printf("%d,%d,%.3f\n", epoch, k++, lr.auc()); + } + } + } + + /** + * Verifies that a classifier with known coefficients does the right thing. + */ + @Test + public void testClassify() { + OnlineLogisticRegression lr = new OnlineLogisticRegression(3, 2, new L2(1)); + // set up some internal coefficients as if we had learned them + lr.setBeta(0, 0, -1); + lr.setBeta(1, 0, -2); + + // zero vector gives no information. All classes are equal. 
+ Vector v = lr.classify(new DenseVector(new double[]{0, 0})); + assertEquals(1 / 3.0, v.get(0), 1e-8); + assertEquals(1 / 3.0, v.get(1), 1e-8); + + v = lr.classifyFull(new DenseVector(new double[]{0, 0})); + assertEquals(1.0, v.zSum(), 1e-8); + assertEquals(1 / 3.0, v.get(0), 1e-8); + assertEquals(1 / 3.0, v.get(1), 1e-8); + assertEquals(1 / 3.0, v.get(2), 1e-8); + + // weights for second vector component are still zero so all classifications are equally likely + v = lr.classify(new DenseVector(new double[]{0, 1})); + assertEquals(1 / 3.0, v.get(0), 1e-3); + assertEquals(1 / 3.0, v.get(1), 1e-3); + + v = lr.classifyFull(new DenseVector(new double[]{0, 1})); + assertEquals(1.0, v.zSum(), 1e-8); + assertEquals(1 / 3.0, v.get(0), 1e-3); + assertEquals(1 / 3.0, v.get(1), 1e-3); + assertEquals(1 / 3.0, v.get(2), 1e-3); + + // but the weights on the first component are non-zero + v = lr.classify(new DenseVector(new double[]{1, 0})); + assertEquals(Math.exp(-1) / (1 + Math.exp(-1) + Math.exp(-2)), v.get(0), 1e-8); + assertEquals(Math.exp(-2) / (1 + Math.exp(-1) + Math.exp(-2)), v.get(1), 1e-8); + + v = lr.classifyFull(new DenseVector(new double[]{1, 0})); + assertEquals(1.0, v.zSum(), 1e-8); + assertEquals(1 / (1 + Math.exp(-1) + Math.exp(-2)), v.get(0), 1e-8); + assertEquals(Math.exp(-1) / (1 + Math.exp(-1) + Math.exp(-2)), v.get(1), 1e-8); + assertEquals(Math.exp(-2) / (1 + Math.exp(-1) + Math.exp(-2)), v.get(2), 1e-8); + + lr.setBeta(0, 1, 1); + + v = lr.classifyFull(new DenseVector(new double[]{1, 1})); + assertEquals(1.0, v.zSum(), 1e-8); + assertEquals(Math.exp(0) / (1 + Math.exp(0) + Math.exp(-2)), v.get(1), 1e-3); + assertEquals(Math.exp(-2) / (1 + Math.exp(0) + Math.exp(-2)), v.get(2), 1e-3); + assertEquals(1 / (1 + Math.exp(0) + Math.exp(-2)), v.get(0), 1e-3); + + lr.setBeta(1, 1, 3); + + v = lr.classifyFull(new DenseVector(new double[]{1, 1})); + assertEquals(1.0, v.zSum(), 1e-8); + assertEquals(Math.exp(0) / (1 + Math.exp(0) + Math.exp(1)), v.get(1), 1e-8); + 
+ assertEquals(Math.exp(1) / (1 + Math.exp(0) + Math.exp(1)), v.get(2), 1e-8); + assertEquals(1 / (1 + Math.exp(0) + Math.exp(1)), v.get(0), 1e-8); + } + + @Test + public void testTrain() throws IOException { + Vector target = readStandardData(); + + + // lambda here needs to be relatively small to avoid swamping the actual signal, but can be + // larger than usual because the data are dense. The learning rate doesn't matter too much + // for this example, but should generally be < 1 + // --passes 1 --rate 50 --lambda 0.001 --input sgd-y.csv --features 21 --output model --noBias --target y --categories 2 --predictors V2 V3 V4 V5 V6 V7 --types n + OnlineLogisticRegression lr = new OnlineLogisticRegression(2, 8, new L1()) + .lambda(1 * 1e-3) + .learningRate(50); + + train(input, target, lr); + test(input, target, lr); + } + + private Vector readStandardData() throws IOException { + // 60 test samples. First column is constant. Second and third are normally distributed from + // either N([2,2], 1) (rows 0...29) or N([-2,-2], 1) (rows 30...59). The first 30 rows have a + // target variable of 0, the last 30 a target of 1. The remaining columns are random noise. 
+ input = readCsv("sgd.csv"); + + // regenerate the target variable + Vector target = new DenseVector(60); + target.assign(0); + target.viewPart(30, 30).assign(1); + return target; + } + + private void train(Matrix input, Vector target, OnlineLearner lr) { + RandomUtils.useTestSeed(); + Random gen = RandomUtils.getRandom(); + + // train on samples in random order (but only one pass) + for (int row : permute(gen, 60)) { + lr.train((int) target.get(row), input.getRow(row)); + } + lr.close(); + } + + private void test(Matrix input, Vector target, AbstractVectorClassifier lr) { + // now test the accuracy + Matrix tmp = lr.classify(input); + // mean(abs(tmp - target)) + double meanAbsoluteError = tmp.getColumn(0).minus(target).aggregate(Functions.plus, Functions.abs) / 60; + + // max(abs(tmp - target)) + double maxAbsoluteError = tmp.getColumn(0).minus(target).aggregate(Functions.max, Functions.abs); + + System.out.printf("mAE = %.4f, maxAE = %.4f\n", meanAbsoluteError, maxAbsoluteError); + assertEquals(0, meanAbsoluteError , 0.05); + assertEquals(0, maxAbsoluteError, 0.3); + + // convenience methods should give the same results + Vector v = lr.classifyScalar(input); + assertEquals(0, v.minus(tmp.getColumn(0)).norm(1), 1e-5); + v = lr.classifyFull(input).getColumn(1); + assertEquals(0, v.minus(tmp.getColumn(0)).norm(1), 1e-4); + } + + /** + * Permute the integers from 0 ... max-1 + * + * @param gen The random number generator to use. + * @param max The number of integers to permute + * @return An array of jumbled integer values + */ + private int[] permute(Random gen, int max) { + int[] permutation = new int[max]; + permutation[0] = 0; + for (int i = 1; i < max; i++) { + int n = gen.nextInt(i + 1); + if (n != i) { + permutation[i] = permutation[n]; + permutation[n] = i; + } else { + permutation[i] = i; + } + } + return permutation; + } + + + /** + * Reads a file containing CSV data. 
This isn't implemented quite the way you might like for a + * real program, but does the job for reading test data. Most notably, it will only read numbers, + * not quoted strings. + * + * @param resourceName Where to get the data. + * @return A matrix of the results. + * @throws java.io.IOException If there is an error reading the data + */ + private Matrix readCsv(String resourceName) throws IOException { + Splitter onCommas = Splitter.on(",").trimResults(CharMatcher.anyOf(" \"")); + + InputStreamReader isr = new InputStreamReader(Resources.getResource(resourceName).openStream()); + List<String> data = CharStreams.readLines(isr); + String first = data.get(0); + data = data.subList(1, data.size()); + + List<String> values = Lists.newArrayList(onCommas.split(first)); + Matrix r = new DenseMatrix(data.size(), values.size()); + + int column = 0; + Map<String, Integer> labels = Maps.newHashMap(); + for (String value : values) { + labels.put(value, column); + column++; + } + r.setColumnLabelBindings(labels); + + int row = 0; + for (String line : data) { + column = 0; + values = Lists.newArrayList(onCommas.split(line)); + for (String value : values) { + r.set(row, column, Double.parseDouble(value)); + column++; + } + row++; + } + + return r; + } +} Added: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/TextValueEncoderTest.java URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/TextValueEncoderTest.java?rev=986045&view=auto ============================================================================== --- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/TextValueEncoderTest.java (added) +++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/TextValueEncoderTest.java Mon Aug 16 16:56:46 2010 @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.sgd; + +import com.google.common.collect.ImmutableMap; +import org.apache.mahout.math.DenseVector; +import org.apache.mahout.math.Vector; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; + +public class TextValueEncoderTest { + @Test + public void testAddToVector() { + TextValueEncoder enc = new TextValueEncoder("text"); + Vector v1 = new DenseVector(200); + enc.addToVector("test1 and more", v1); + // should set 6 distinct locations to 1 + assertEquals(6.0, v1.norm(1), 0); + assertEquals(1.0, v1.maxValue(), 0); + + // now some fancy weighting + StaticWordValueEncoder w = new StaticWordValueEncoder("text"); + w.setDictionary(ImmutableMap.<String, Double>of("word1", 3.0, "word2", 1.5)); + enc.setWordEncoder(w); + + // should set 6 locations to something + Vector v2 = new DenseVector(200); + enc.addToVector("test1 and more", v2); + + // this should set the same 6 locations to the same values + Vector v3 = new DenseVector(200); + w.addToVector("test1", v3); + w.addToVector("and", v3); + w.addToVector("more", v3); + + assertEquals(0, v3.minus(v2).norm(1), 0); + + // moreover, the locations set in the unweighted case should be the same as in the weighted case + assertEquals(v3.zSum(), v3.dot(v1), 0); + } + + @Test + 
public void testAsString() { + TextValueEncoder enc = new TextValueEncoder("text"); + assertEquals("[text:test1:1.0000, text:and:1.0000, text:more:1.0000]", enc.asString("test1 and more")); + } +} Added: mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/WordLikeValueEncoderTest.java URL: http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/WordLikeValueEncoderTest.java?rev=986045&view=auto ============================================================================== --- mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/WordLikeValueEncoderTest.java (added) +++ mahout/trunk/core/src/test/java/org/apache/mahout/classifier/sgd/WordLikeValueEncoderTest.java Mon Aug 16 16:56:46 2010 @@ -0,0 +1,91 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.classifier.sgd; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import org.apache.mahout.math.DenseVector; +import org.apache.mahout.math.Vector; +import org.junit.Test; + +import java.util.Iterator; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; + +public class WordLikeValueEncoderTest { + @Test + public void testAddToVector() { + FeatureVectorEncoder enc = new StaticWordValueEncoder("word"); + Vector v = new DenseVector(200); + enc.addToVector("word1", v); + enc.addToVector("word2", v); + Iterator<Vector.Element> i = v.iterateNonZero(); + Iterator<Integer> j = ImmutableList.of(7, 118, 119, 199).iterator(); + while (i.hasNext()) { + Vector.Element element = i.next(); + assertEquals(j.next().intValue(), element.index()); + assertEquals(1, element.get(), 0); + } + assertFalse(j.hasNext()); + } + + @Test + public void testAsString() { + FeatureVectorEncoder enc = new StaticWordValueEncoder("word"); + assertEquals("word:w1:1.0000", enc.asString("w1")); + } + + @Test + public void testStaticWeights() { + StaticWordValueEncoder enc = new StaticWordValueEncoder("word"); + enc.setDictionary(ImmutableMap.<String, Double>of("word1", 3.0, "word2", 1.5)); + Vector v = new DenseVector(200); + enc.addToVector("word1", v); + enc.addToVector("word2", v); + enc.addToVector("word3", v); + Iterator<Vector.Element> i = v.iterateNonZero(); + Iterator<Integer> j = ImmutableList.of(7, 101, 118, 119, 152, 199).iterator(); + Iterator<Double> k = ImmutableList.of(3.0, 0.75, 1.5, 1.5, 0.75, 3.0).iterator(); + while (i.hasNext()) { + Vector.Element element = i.next(); + assertEquals(j.next().intValue(), element.index()); + assertEquals(k.next(), element.get(), 0); + } + assertFalse(j.hasNext()); + } + + @Test + public void testDynamicWeights() { + FeatureVectorEncoder enc = new AdaptiveWordValueEncoder("word"); + Vector v = new DenseVector(200); + 
enc.addToVector("word1", v); // weight is log(2/1.5) + enc.addToVector("word2", v); // weight is log(3.5 / 1.5) + enc.addToVector("word1", v); // weight is log(4.5 / 2.5) (but overlays on first value) + enc.addToVector("word3", v); // weight is log(6 / 1.5) + Iterator<Vector.Element> i = v.iterateNonZero(); + Iterator<Integer> j = ImmutableList.of(7, 101, 118, 119, 152, 199).iterator(); + Iterator<Double> k = ImmutableList.of(Math.log(2 / 1.5) + Math.log(4.5 / 2.5), Math.log(6 / 1.5), Math.log(3.5 / 1.5), Math.log(3.5 / 1.5), Math.log(6 / 1.5), Math.log(2 / 1.5) + Math.log(4.5 / 2.5)).iterator(); + while (i.hasNext()) { + Vector.Element element = i.next(); + assertEquals(j.next().intValue(), element.index()); + assertEquals(k.next(), element.get(), 1e-6); + } + assertFalse(j.hasNext()); + } +} Added: mahout/trunk/math/src/main/java/org/apache/mahout/math/stats/OnlineAuc.java URL: http://svn.apache.org/viewvc/mahout/trunk/math/src/main/java/org/apache/mahout/math/stats/OnlineAuc.java?rev=986045&view=auto ============================================================================== --- mahout/trunk/math/src/main/java/org/apache/mahout/math/stats/OnlineAuc.java (added) +++ mahout/trunk/math/src/main/java/org/apache/mahout/math/stats/OnlineAuc.java Mon Aug 16 16:56:46 2010 @@ -0,0 +1,100 @@ +package org.apache.mahout.math.stats; + +import org.apache.mahout.math.DenseMatrix; +import org.apache.mahout.math.DenseVector; +import org.apache.mahout.math.Matrix; +import org.apache.mahout.math.Vector; + +import java.util.Random; + +/** + * Computes a running estimate of AUC (see http://en.wikipedia.org/wiki/Receiver_operating_characteristic). + * <p/> + * Since AUC is normally a global property of labeled scores, it is almost always computed in a + * batch fashion. The probabilistic definition (the probability that a random element of one set + * has a higher score than a random element of another set) gives us a way to estimate this + * on-line. 
+ */ +public class OnlineAuc { + private Random random = new Random(); + + enum ReplacementPolicy { + FIFO, FAIR, RANDOM + } + + public static final int HISTORY = 10; + + private ReplacementPolicy policy = ReplacementPolicy.FAIR; + + private Matrix scores; + private Vector averages; + + private Vector samples; + + public OnlineAuc() { + int numCategories = 2; + scores = new DenseMatrix(numCategories, HISTORY); + scores.assign(Double.NaN); + averages = new DenseVector(numCategories); + averages.assign(0.5); + samples = new DenseVector(numCategories); + } + + public double addSample(int category, final double score) { + int n = (int) samples.get(category); + if (n < HISTORY) { + scores.set(category, n, score); + } else { + switch (policy) { + case FIFO: + scores.set(category, n % HISTORY, score); + break; + case FAIR: + int j = random.nextInt(n + 1); + if (j < HISTORY) { + scores.set(category, j, score); + } + break; + case RANDOM: + j = random.nextInt(HISTORY); + scores.set(category, j, score); + break; + } + } + + samples.set(category, n + 1); + + if (samples.minValue() >= 1) { + // compare to previous scores for other category + Vector row = scores.viewRow(1 - category); + double m = 0; + int count = 0; + for (Vector.Element element : row) { + double v = element.get(); + if (!Double.isNaN(v)) { + count++; + double z = 0.5; + if (score > v) { + z = 1; + } else if (score < v) { + z = 0; + } + m += (z - m) / count; + } else { + break; + } + } + averages.set(category, averages.get(category) + (m - averages.get(category)) / samples.get(category)); + } + return auc(); + } + + public double auc() { + // return an unweighted average of all averages. 
+ return 0.5 - averages.get(0) / 2 + averages.get(1) / 2; + } + + public void setPolicy(ReplacementPolicy policy) { + this.policy = policy; + } +} Added: mahout/trunk/math/src/test/java/org/apache/mahout/math/stats/OnlineAucTest.java URL: http://svn.apache.org/viewvc/mahout/trunk/math/src/test/java/org/apache/mahout/math/stats/OnlineAucTest.java?rev=986045&view=auto ============================================================================== --- mahout/trunk/math/src/test/java/org/apache/mahout/math/stats/OnlineAucTest.java (added) +++ mahout/trunk/math/src/test/java/org/apache/mahout/math/stats/OnlineAucTest.java Mon Aug 16 16:56:46 2010 @@ -0,0 +1,64 @@ +package org.apache.mahout.math.stats; + +import org.junit.Assert; +import org.junit.Test; + +import java.util.Random; + +public class OnlineAucTest { + @Test + public void binaryCase() { + OnlineAuc a1 = new OnlineAuc(); + a1.setPolicy(OnlineAuc.ReplacementPolicy.FAIR); + + OnlineAuc a2 = new OnlineAuc(); + a2.setPolicy(OnlineAuc.ReplacementPolicy.FIFO); + + OnlineAuc a3 = new OnlineAuc(); + a3.setPolicy(OnlineAuc.ReplacementPolicy.RANDOM); + + Random gen = new Random(1); + for (int i = 0; i < 10000; i++) { + double x = gen.nextGaussian(); + + a1.addSample(0, x); + a2.addSample(0, x); + a3.addSample(0, x); + + x = gen.nextGaussian() + 1; + + a1.addSample(1, x); + a2.addSample(1, x); + a3.addSample(1, x); + } + + a1 = new OnlineAuc(); + a1.setPolicy(OnlineAuc.ReplacementPolicy.FAIR); + + a2 = new OnlineAuc(); + a2.setPolicy(OnlineAuc.ReplacementPolicy.FIFO); + + a3 = new OnlineAuc(); + a3.setPolicy(OnlineAuc.ReplacementPolicy.RANDOM); + + gen = new Random(1); + for (int i = 0; i < 10000; i++) { + double x = gen.nextGaussian(); + + a1.addSample(1, x); + a2.addSample(1, x); + a3.addSample(1, x); + + x = gen.nextGaussian() + 1; + + a1.addSample(0, x); + a2.addSample(0, x); + a3.addSample(0, x); + } + + // reference value computed using R: mean(rnorm(1000000) < rnorm(1000000,1)) + Assert.assertEquals(1 - 0.76, 
a1.auc(), 0.05); + Assert.assertEquals(1 - 0.76, a2.auc(), 0.05); + Assert.assertEquals(1 - 0.76, a3.auc(), 0.05); + } +}
