http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/tools/VisualizerTest.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/tools/VisualizerTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/tools/VisualizerTest.java new file mode 100644 index 0000000..aa15410 --- /dev/null +++ b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/df/tools/VisualizerTest.java @@ -0,0 +1,211 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */

package org.apache.mahout.classifier.df.tools;

import org.apache.mahout.classifier.df.DecisionForest;
import org.apache.mahout.classifier.df.builder.DecisionTreeBuilder;
import org.apache.mahout.classifier.df.data.Data;
import org.apache.mahout.classifier.df.data.DataLoader;
import org.apache.mahout.classifier.df.data.Dataset;
import org.apache.mahout.classifier.df.data.Instance;
import org.apache.mahout.classifier.df.node.CategoricalNode;
import org.apache.mahout.classifier.df.node.Leaf;
import org.apache.mahout.classifier.df.node.Node;
import org.apache.mahout.classifier.df.node.NumericalNode;
import org.apache.mahout.common.MahoutTestCase;
import org.apache.mahout.common.RandomUtils;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

import java.text.DecimalFormat;
import java.util.List;
import java.util.Random;
import java.util.ArrayList;

/**
 * Checks the textual output of {@code TreeVisualizer} and {@code ForestVisualizer}
 * against a small five-column data set (outlook, temperature, humidity, windy, play).
 *
 * <p>Where two child orderings are acceptable, the assertions accept either of two
 * expected renderings.
 */
@Deprecated
public final class VisualizerTest extends MahoutTestCase {

  /** Decimal separator of the default number format; spliced into expected thresholds. */
  private static final char DECIMAL_SEPARATOR =
      ((DecimalFormat) DecimalFormat.getInstance()).getDecimalFormatSymbols().getDecimalSeparator();

  private static final String[] TRAIN_DATA = {
      "sunny,85,85,FALSE,no",
      "sunny,80,90,TRUE,no",
      "overcast,83,86,FALSE,yes",
      "rainy,70,96,FALSE,yes",
      "rainy,68,80,FALSE,yes",
      "rainy,65,70,TRUE,no",
      "overcast,64,65,TRUE,yes",
      "sunny,72,95,FALSE,no",
      "sunny,69,70,FALSE,yes",
      "rainy,75,80,FALSE,yes",
      "sunny,75,70,TRUE,yes",
      "overcast,72,90,TRUE,yes",
      "overcast,81,75,FALSE,yes",
      "rainy,71,91,TRUE,no"};

  private static final String[] TEST_DATA = {
      "rainy,70,96,TRUE,-",
      "overcast,64,65,TRUE,-",
      "sunny,75,90,TRUE,-"};

  private static final String[] ATTRIBUTE_NAMES = {
      "outlook", "temperature", "humidity", "windy", "play"};

  private Random rng;

  private Data trainingData;

  private Data testData;

  /** Seeds the random generator with 1 and loads both data sets against one shared dataset. */
  @Override
  @Before
  public void setUp() throws Exception {
    super.setUp();

    rng = RandomUtils.getRandom(1);

    // NOTE(review): the descriptor presumably declares categorical/numerical/label columns -- confirm.
    Dataset dataset = DataLoader.generateDataset("C N N C L", false, TRAIN_DATA);
    trainingData = DataLoader.loadData(dataset, TRAIN_DATA);
    testData = DataLoader.loadData(dataset, TEST_DATA);
  }

  /** Returns a builder whose M is the training dataset's attribute count minus one. */
  private DecisionTreeBuilder newBuilder() {
    DecisionTreeBuilder treeBuilder = new DecisionTreeBuilder();
    treeBuilder.setM(trainingData.getDataset().nbAttributes() - 1);
    return treeBuilder;
  }

  @Test
  public void testTreeVisualize() throws Exception {
    Node tree = newBuilder().build(rng, trainingData);

    String rendered = TreeVisualizer.toString(tree, trainingData.getDataset(), ATTRIBUTE_NAMES);

    String sunnyBeforeOvercast = String.format("\n"
        + "outlook = rainy\n"
        + "| windy = FALSE : yes\n"
        + "| windy = TRUE : no\n"
        + "outlook = sunny\n"
        + "| humidity < 77%s5 : yes\n"
        + "| humidity >= 77%s5 : no\n"
        + "outlook = overcast : yes", DECIMAL_SEPARATOR, DECIMAL_SEPARATOR);
    String overcastBeforeSunny = String.format("\n"
        + "outlook = rainy\n"
        + "| windy = TRUE : no\n"
        + "| windy = FALSE : yes\n"
        + "outlook = overcast : yes\n"
        + "outlook = sunny\n"
        + "| humidity < 77%s5 : yes\n"
        + "| humidity >= 77%s5 : no", DECIMAL_SEPARATOR, DECIMAL_SEPARATOR);

    assertTrue(sunnyBeforeOvercast.equals(rendered) || overcastBeforeSunny.equals(rendered));
  }

  @Test
  public void testPredictTrace() throws Exception {
    Node tree = newBuilder().build(rng, trainingData);

    String[] trace = TreeVisualizer.predictTrace(tree, testData, ATTRIBUTE_NAMES);

    String[] expectedTrace = {
        "outlook = rainy -> windy = TRUE -> no",
        "outlook = overcast -> yes",
        String.format("outlook = sunny -> (humidity = 90) >= 77%s5 -> no", DECIMAL_SEPARATOR)};
    Assert.assertArrayEquals(expectedTrace, trace);
  }

  @Test
  public void testForestVisualize() throws Exception {
    // Hand-built tree: a numerical root on attribute 2, a categorical child on
    // attribute 0, and below that a numerical split on attribute 1.
    Node temperatureSplit = new NumericalNode(1, 71, new Leaf(0), new Leaf(1));
    Node outlookSplit = new CategoricalNode(0, new double[] {0, 1, 2},
        new Node[] {temperatureSplit, new Leaf(1), new Leaf(0)});
    NumericalNode root = new NumericalNode(2, 90, new Leaf(0), outlookSplit);

    List<Node> trees = new ArrayList<>();
    trees.add(root);
    DecisionForest forest = new DecisionForest(trees);

    // Without attribute names the attribute indices are printed.
    String rendered = ForestVisualizer.toString(forest, trainingData.getDataset(), null);
    String indexedFirst = "Tree[1]:\n2 < 90 : yes\n2 >= 90\n"
        + "| 0 = rainy\n"
        + "| | 1 < 71 : yes\n"
        + "| | 1 >= 71 : no\n"
        + "| 0 = sunny : no\n"
        + "| 0 = overcast : yes\n";
    String indexedSecond = "Tree[1]:\n"
        + "2 < 90 : no\n"
        + "2 >= 90\n"
        + "| 0 = rainy\n"
        + "| | 1 < 71 : no\n"
        + "| | 1 >= 71 : yes\n"
        + "| 0 = overcast : yes\n"
        + "| 0 = sunny : no\n";
    assertTrue(indexedFirst.equals(rendered) || indexedSecond.equals(rendered));

    // With attribute names the names replace the indices.
    rendered = ForestVisualizer.toString(forest, trainingData.getDataset(), ATTRIBUTE_NAMES);
    String namedFirst = "Tree[1]:\n"
        + "humidity < 90 : yes\n"
        + "humidity >= 90\n"
        + "| outlook = rainy\n"
        + "| | temperature < 71 : yes\n"
        + "| | temperature >= 71 : no\n"
        + "| outlook = sunny : no\n"
        + "| outlook = overcast : yes\n";
    String namedSecond = "Tree[1]:\n"
        + "humidity < 90 : no\n"
        + "humidity >= 90\n"
        + "| outlook = rainy\n"
        + "| | temperature < 71 : no\n"
        + "| | temperature >= 71 : yes\n"
        + "| outlook = overcast : yes\n"
        + "| outlook = sunny : no\n";
    assertTrue(namedFirst.equals(rendered) || namedSecond.equals(rendered));
  }

  @Test
  public void testLeafless() throws Exception {
    // Keep only the instances whose first attribute value is not 0.0.
    List<Instance> kept = new ArrayList<>();
    for (int row = 0; row < trainingData.size(); row++) {
      Instance candidate = trainingData.get(row);
      if (candidate.get(0) != 0.0d) {
        kept.add(candidate);
      }
    }
    Data lessData = new Data(trainingData.getDataset(), kept);

    DecisionTreeBuilder treeBuilder = newBuilder();
    treeBuilder.setMinSplitNum(0);
    treeBuilder.setComplemented(false);
    Node tree = treeBuilder.build(rng, lessData);

    String rendered = TreeVisualizer.toString(tree, trainingData.getDataset(), ATTRIBUTE_NAMES);

    String sunnyFirst = String.format("\noutlook = sunny\n"
        + "| humidity < 77%s5 : yes\n"
        + "| humidity >= 77%s5 : no\n"
        + "outlook = overcast : yes", DECIMAL_SEPARATOR, DECIMAL_SEPARATOR);
    String overcastFirst = String.format("\noutlook = overcast : yes\n"
        + "outlook = sunny\n"
        + "| humidity < 77%s5 : yes\n"
        + "| humidity >= 77%s5 : no", DECIMAL_SEPARATOR, DECIMAL_SEPARATOR);

    assertTrue(sunnyFirst.equals(rendered) || overcastFirst.equals(rendered));
  }

  @Test
  public void testEmpty() throws Exception {
    Data emptyData = new Data(trainingData.getDataset());

    // Deliberately a default builder: M is not set for the empty-data case.
    Node tree = new DecisionTreeBuilder().build(rng, emptyData);

    assertEquals(" : unknown", TreeVisualizer.toString(tree, trainingData.getDataset(), ATTRIBUTE_NAMES));
  }
}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/classifier/evaluation/AucTest.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/classifier/evaluation/AucTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/evaluation/AucTest.java new file mode 100644 index 0000000..66fe97b --- /dev/null +++ b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/evaluation/AucTest.java @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */

package org.apache.mahout.classifier.evaluation;

import org.apache.mahout.common.MahoutTestCase;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.jet.random.Normal;
import org.junit.Test;

import java.util.Random;

/**
 * Exercises {@code Auc} with synthetic scores: the AUC estimate for two shifted
 * Gaussian score populations, the same with tied scores added, and the entropy matrix.
 */
public final class AucTest extends MahoutTestCase {

  /** Number of samples drawn for each of the two classes. */
  private static final int SAMPLES_PER_CLASS = 100000;

  /**
   * Builds an {@code Auc} with probability scoring switched off, fed with
   * {@link #SAMPLES_PER_CLASS} standard-Gaussian scores for class 0 interleaved with
   * the same number of Gaussian scores shifted by +1 for class 1.
   */
  private static Auc gaussianAuc() {
    Auc auc = new Auc();
    Random gen = RandomUtils.getRandom();
    auc.setProbabilityScore(false);
    for (int sample = 0; sample < SAMPLES_PER_CLASS; sample++) {
      auc.add(0, gen.nextGaussian());
      auc.add(1, gen.nextGaussian() + 1);
    }
    return auc;
  }

  /** Returns {@code positive.pdf(score) / (negative.pdf(score) + positive.pdf(score))}. */
  private static double posterior(Normal negative, Normal positive, double score) {
    double positiveDensity = positive.pdf(score);
    double negativeDensity = negative.pdf(score);
    return positiveDensity / (negativeDensity + positiveDensity);
  }

  @Test
  public void testAuc() {
    assertEquals(0.76, gaussianAuc().auc(), 0.01);
  }

  @Test
  public void testTies() {
    Auc auc = gaussianAuc();

    // ties outside the normal range could cause index out of range
    for (int tie = 0; tie < 4; tie++) {
      auc.add(0, 5.0);
    }
    for (int tie = 0; tie < 3; tie++) {
      auc.add(1, 5.0);
    }

    assertEquals(0.76, auc.auc(), 0.05);
  }

  @Test
  public void testEntropy() {
    Auc auc = new Auc();
    Random gen = RandomUtils.getRandom();
    Normal n0 = new Normal(-1, 1, gen);
    Normal n1 = new Normal(1, 1, gen);
    for (int sample = 0; sample < SAMPLES_PER_CLASS; sample++) {
      // One draw from each population per round, scored by the posterior of class 1.
      auc.add(0, posterior(n0, n1, n0.nextDouble()));
      auc.add(1, posterior(n0, n1, n1.nextDouble()));
    }

    Matrix entropy = auc.entropy();
    assertEquals(-0.35, entropy.get(0, 0), 0.02);
    assertEquals(-2.36, entropy.get(0, 1), 0.02);
    assertEquals(-2.36, entropy.get(1, 0), 0.02);
    assertEquals(-0.35, entropy.get(1, 1), 0.02);
  }
}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifierTest.java
---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifierTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifierTest.java new file mode 100644 index 0000000..f658738 --- /dev/null +++ b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/naivebayes/ComplementaryNaiveBayesClassifierTest.java @@ -0,0 +1,47 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */

package org.apache.mahout.classifier.naivebayes;

import org.apache.mahout.math.DenseVector;
import org.junit.Before;
import org.junit.Test;

/**
 * Checks that {@link ComplementaryNaiveBayesClassifier}, built on the complementary
 * fixture model, ranks category {@code i} highest for the i-th unit vector.
 */
public final class ComplementaryNaiveBayesClassifierTest extends NaiveBayesTestBase {

  private ComplementaryNaiveBayesClassifier classifier;

  @Override
  @Before
  public void setUp() throws Exception {
    super.setUp();
    classifier = new ComplementaryNaiveBayesClassifier(createComplementaryNaiveBayesModel());
  }

  @Test
  public void testNaiveBayes() throws Exception {
    int categories = 4;
    assertEquals(categories, classifier.numCategories());

    for (int expected = 0; expected < categories; expected++) {
      // One-hot input: only the feature at position 'expected' is set.
      double[] oneHot = new double[categories];
      oneHot[expected] = 1.0;
      assertEquals(expected, maxIndex(classifier.classifyFull(new DenseVector(oneHot))));
    }
  }

}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesModelTest.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesModelTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesModelTest.java new file mode 100644 index 0000000..3b83492 --- /dev/null +++ b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesModelTest.java @@ -0,0 +1,36 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */

package org.apache.mahout.classifier.naivebayes;

import org.junit.Test;

/**
 * Validates the standard and complementary models exposed by {@link NaiveBayesTestBase}.
 */
public class NaiveBayesModelTest extends NaiveBayesTestBase {

  /** Both fixture models must pass {@code NaiveBayesModel.validate()}. */
  @Test
  public void testRandomModelGeneration() {
    // make sure the standard model is valid
    getStandardModel().validate();

    // same for Complementary
    getComplementaryModel().validate();
  }

}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesTest.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesTest.java new file mode 100644 index 0000000..abd666e --- /dev/null +++ b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesTest.java @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */

package org.apache.mahout.classifier.naivebayes;

import java.io.File;

import com.google.common.io.Closeables;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.SequenceFile;
import org.apache.hadoop.io.Text;
import org.apache.mahout.classifier.AbstractVectorClassifier;
import org.apache.mahout.classifier.naivebayes.training.TrainNaiveBayesJob;
import org.apache.mahout.common.MahoutTestCase;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
import org.apache.mahout.math.hadoop.MathHelper;
import org.junit.Before;
import org.junit.Test;

/**
 * End-to-end test: trains a naive Bayes model with {@link TrainNaiveBayesJob} on ten
 * labelled toy instances and checks how a red domestic SUV is classified, for both
 * the standard and the complementary classifier.
 */
public class NaiveBayesTest extends MahoutTestCase {

  private Configuration conf;
  private File inputFile;
  private File outputDir;
  private File tempDir;

  static final Text LABEL_STOLEN = new Text("/stolen/");
  static final Text LABEL_NOT_STOLEN = new Text("/not_stolen/");

  // One indicator feature per attribute value; the first argument is the vector index.
  static final Vector.Element COLOR_RED = MathHelper.elem(0, 1);
  static final Vector.Element COLOR_YELLOW = MathHelper.elem(1, 1);
  static final Vector.Element TYPE_SPORTS = MathHelper.elem(2, 1);
  static final Vector.Element TYPE_SUV = MathHelper.elem(3, 1);
  static final Vector.Element ORIGIN_DOMESTIC = MathHelper.elem(4, 1);
  static final Vector.Element ORIGIN_IMPORTED = MathHelper.elem(5, 1);

  /** Prepares the temp locations and writes the training instances to a sequence file. */
  @Override
  @Before
  public void setUp() throws Exception {
    super.setUp();

    conf = getConfiguration();

    inputFile = getTestTempFile("trainingInstances.seq");
    outputDir = getTestTempDir("output");
    // NOTE(review): presumably removed because the training job expects to create its
    // own output directory -- confirm. The result of delete() is ignored, as before.
    outputDir.delete();
    tempDir = getTestTempDir("tmp");

    Path inputPath = new Path(inputFile.getAbsolutePath());
    SequenceFile.Writer out =
        new SequenceFile.Writer(FileSystem.get(conf), conf, inputPath, Text.class, VectorWritable.class);
    try {
      out.append(LABEL_STOLEN, trainingInstance(COLOR_RED, TYPE_SPORTS, ORIGIN_DOMESTIC));
      out.append(LABEL_NOT_STOLEN, trainingInstance(COLOR_RED, TYPE_SPORTS, ORIGIN_DOMESTIC));
      out.append(LABEL_STOLEN, trainingInstance(COLOR_RED, TYPE_SPORTS, ORIGIN_DOMESTIC));
      out.append(LABEL_NOT_STOLEN, trainingInstance(COLOR_YELLOW, TYPE_SPORTS, ORIGIN_DOMESTIC));
      out.append(LABEL_STOLEN, trainingInstance(COLOR_YELLOW, TYPE_SPORTS, ORIGIN_IMPORTED));
      out.append(LABEL_NOT_STOLEN, trainingInstance(COLOR_YELLOW, TYPE_SUV, ORIGIN_IMPORTED));
      out.append(LABEL_STOLEN, trainingInstance(COLOR_YELLOW, TYPE_SUV, ORIGIN_IMPORTED));
      out.append(LABEL_NOT_STOLEN, trainingInstance(COLOR_YELLOW, TYPE_SUV, ORIGIN_DOMESTIC));
      out.append(LABEL_NOT_STOLEN, trainingInstance(COLOR_RED, TYPE_SUV, ORIGIN_IMPORTED));
      out.append(LABEL_STOLEN, trainingInstance(COLOR_RED, TYPE_SPORTS, ORIGIN_IMPORTED));
    } finally {
      Closeables.close(out, false);
    }
  }

  @Test
  public void toyData() throws Exception {
    Vector prediction = trainAndScoreRedDomesticSuv(false);

    // should be classified as not stolen
    assertTrue(prediction.get(0) < prediction.get(1));
  }

  @Test
  public void toyDataComplementary() throws Exception {
    Vector prediction = trainAndScoreRedDomesticSuv(true);

    // should be classified as not stolen
    assertTrue(prediction.get(0) < prediction.get(1));
  }

  /**
   * Runs the training job, materializes the model from the output directory, asserts
   * that the classifier reports two categories and scores a red domestic SUV.
   *
   * @param complementary when true, passes {@code --trainComplementary} to the job and
   *                      scores with the complementary classifier instead of the standard one
   * @return the per-category scores returned by {@code classifyFull}
   */
  private Vector trainAndScoreRedDomesticSuv(boolean complementary) throws Exception {
    String[] jobArgs;
    if (complementary) {
      jobArgs = new String[] {
          "--input", inputFile.getAbsolutePath(),
          "--output", outputDir.getAbsolutePath(),
          "--trainComplementary",
          "--tempDir", tempDir.getAbsolutePath()};
    } else {
      jobArgs = new String[] {
          "--input", inputFile.getAbsolutePath(),
          "--output", outputDir.getAbsolutePath(),
          "--tempDir", tempDir.getAbsolutePath()};
    }

    TrainNaiveBayesJob trainingJob = new TrainNaiveBayesJob();
    trainingJob.setConf(conf);
    trainingJob.run(jobArgs);

    NaiveBayesModel model = NaiveBayesModel.materialize(new Path(outputDir.getAbsolutePath()), conf);

    AbstractVectorClassifier classifier;
    if (complementary) {
      classifier = new ComplementaryNaiveBayesClassifier(model);
    } else {
      classifier = new StandardNaiveBayesClassifier(model);
    }

    assertEquals(2, classifier.numCategories());

    return classifier.classifyFull(trainingInstance(COLOR_RED, TYPE_SUV, ORIGIN_DOMESTIC).get());
  }

  /**
   * Builds a six-element dense vector and copies each given element's value to that
   * element's index; the remaining positions keep their initial value.
   */
  static VectorWritable trainingInstance(Vector.Element... elems) {
    DenseVector features = new DenseVector(6);
    for (Vector.Element elem : elems) {
      features.set(elem.index(), elem.get());
    }
    return new VectorWritable(features);
  }

}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesTestBase.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesTestBase.java b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesTestBase.java new file mode 100644 index 0000000..a943b7b --- /dev/null +++ b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/naivebayes/NaiveBayesTestBase.java @@ -0,0 +1,135 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */

package org.apache.mahout.classifier.naivebayes;

import org.apache.mahout.common.MahoutTestCase;
import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.Vector.Element;

/**
 * Shared fixtures for the naive Bayes tests: a fixed 4x4 weight matrix with matching
 * label and feature sums, the standard and complementary models built from them, and
 * helpers for computing theta weights and picking the best-scoring category.
 */
public abstract class NaiveBayesTestBase extends MahoutTestCase {

  private NaiveBayesModel standardModel;
  private NaiveBayesModel complementaryModel;

  /** Builds and validates one standard and one complementary model per test. */
  @Override
  public void setUp() throws Exception {
    super.setUp();
    standardModel = createStandardNaiveBayesModel();
    standardModel.validate();
    complementaryModel = createComplementaryNaiveBayesModel();
    complementaryModel.validate();
  }

  protected NaiveBayesModel getStandardModel() {
    return standardModel;
  }

  protected NaiveBayesModel getComplementaryModel() {
    return complementaryModel;
  }

  /**
   * Sums {@code |log(numerator / denominator)|} over all features, where the numerator is
   * {@code featureSum[i] - weightMatrix[i, label] + 1} and the denominator is
   * {@code featureSum.zSum() - labelSum[label] + featureSum.size()}.
   *
   * @param label        column of {@code weightMatrix} and index into {@code labelSum}
   * @param weightMatrix weights, read as {@code get(feature, label)}
   * @param labelSum     per-label totals
   * @param featureSum   per-feature totals
   * @return the accumulated weight for {@code label}
   */
  protected static double complementaryNaiveBayesThetaWeight(int label,
                                                             Matrix weightMatrix,
                                                             Vector labelSum,
                                                             Vector featureSum) {
    double alpha = 1.0;
    int numFeatures = featureSum.size();
    // The denominator does not depend on the feature, so it is computed once instead of
    // re-summing featureSum on every iteration.
    double denominator = featureSum.zSum() - labelSum.get(label) + numFeatures;

    double weight = 0.0;
    for (int feature = 0; feature < numFeatures; feature++) {
      double numerator = featureSum.get(feature) - weightMatrix.get(feature, label) + alpha;
      weight += Math.abs(Math.log(numerator / denominator));
    }
    return weight;
  }

  /**
   * Sums {@code |log(numerator / denominator)|} over all features, where the numerator is
   * {@code weightMatrix[feature, label] + 1} and the denominator is
   * {@code labelSum[label] + featureSum.size()}.
   *
   * @param label        column of {@code weightMatrix} and index into {@code labelSum}
   * @param weightMatrix weights, read as {@code get(feature, label)}
   * @param labelSum     per-label totals
   * @param featureSum   per-feature totals; only its size is used
   * @return the accumulated weight for {@code label}
   */
  protected static double naiveBayesThetaWeight(int label,
                                                Matrix weightMatrix,
                                                Vector labelSum,
                                                Vector featureSum) {
    double alpha = 1.0;
    int numFeatures = featureSum.size();
    // Loop-invariant: depends only on the label and the number of features.
    double denominator = labelSum.get(label) + numFeatures;

    double weight = 0.0;
    for (int feature = 0; feature < numFeatures; feature++) {
      double numerator = weightMatrix.get(feature, label) + alpha;
      weight += Math.abs(Math.log(numerator / denominator));
    }
    return weight;
  }

  /** Returns a fresh copy of the fixture weight matrix shared by both models. */
  private static DenseMatrix fixtureWeightMatrix() {
    return new DenseMatrix(new double[][] {
        { 0.7, 0.1, 0.1, 0.3 },
        { 0.4, 0.4, 0.1, 0.1 },
        { 0.1, 0.0, 0.8, 0.1 },
        { 0.1, 0.1, 0.1, 0.7 } });
  }

  /** Returns a fresh copy of the fixture per-label sums. */
  private static DenseVector fixtureLabelSum() {
    return new DenseVector(new double[] { 1.2, 1.0, 1.0, 1.0 });
  }

  /** Returns a fresh copy of the fixture per-feature sums. */
  private static DenseVector fixtureFeatureSum() {
    return new DenseVector(new double[] { 1.3, 0.6, 1.1, 1.2 });
  }

  /** Builds the standard model: no theta normalizer, alpha 1.0, complementary flag off. */
  protected static NaiveBayesModel createStandardNaiveBayesModel() {
    DenseMatrix weightMatrix = fixtureWeightMatrix();
    DenseVector labelSum = fixtureLabelSum();
    DenseVector featureSum = fixtureFeatureSum();

    // now generate the model
    return new NaiveBayesModel(weightMatrix, featureSum, labelSum, null, 1.0f, false);
  }

  /**
   * Builds the complementary model: same fixture data as the standard model, plus a theta
   * normalizer computed per label with {@link #complementaryNaiveBayesThetaWeight}.
   */
  protected static NaiveBayesModel createComplementaryNaiveBayesModel() {
    DenseMatrix weightMatrix = fixtureWeightMatrix();
    DenseVector labelSum = fixtureLabelSum();
    DenseVector featureSum = fixtureFeatureSum();

    double[] thetaNormalizerSum = new double[labelSum.size()];
    for (int label = 0; label < thetaNormalizerSum.length; label++) {
      thetaNormalizerSum[label] =
          complementaryNaiveBayesThetaWeight(label, weightMatrix, labelSum, featureSum);
    }

    // now generate the model
    return new NaiveBayesModel(weightMatrix, featureSum, labelSum, new DenseVector(thetaNormalizerSum), 1.0f, true);
  }

  /**
   * Returns the index of the largest element of {@code instance}; on ties the last such
   * index wins. Returns -1 when no element compares {@code >=} to negative infinity,
   * i.e. for an empty vector or one holding only NaN.
   *
   * <p>The running maximum starts at {@link Double#NEGATIVE_INFINITY}. It used to start at
   * {@code Integer.MIN_VALUE}, which made the method return -1 for vectors whose scores
   * were all below roughly -2.1e9.
   */
  protected static int maxIndex(Vector instance) {
    int maxIndex = -1;
    double maxScore = Double.NEGATIVE_INFINITY;
    for (Element label : instance.all()) {
      if (label.get() >= maxScore) {
        maxIndex = label.index();
        maxScore = label.get();
      }
    }
    return maxIndex;
  }
}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifierTest.java
---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifierTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifierTest.java new file mode 100644 index 0000000..a432ac9 --- /dev/null +++ b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/naivebayes/StandardNaiveBayesClassifierTest.java @@ -0,0 +1,47 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */

package org.apache.mahout.classifier.naivebayes;

import org.apache.mahout.math.DenseVector;
import org.junit.Before;
import org.junit.Test;

/**
 * Checks that {@link StandardNaiveBayesClassifier}, built on the standard fixture
 * model, ranks category {@code i} highest for the i-th unit vector.
 */
public final class StandardNaiveBayesClassifierTest extends NaiveBayesTestBase {

  private StandardNaiveBayesClassifier classifier;

  @Override
  @Before
  public void setUp() throws Exception {
    super.setUp();
    classifier = new StandardNaiveBayesClassifier(createStandardNaiveBayesModel());
  }

  @Test
  public void testNaiveBayes() throws Exception {
    int categories = 4;
    assertEquals(categories, classifier.numCategories());

    for (int expected = 0; expected < categories; expected++) {
      // One-hot input: only the feature at position 'expected' is set.
      double[] oneHot = new double[categories];
      oneHot[expected] = 1.0;
      assertEquals(expected, maxIndex(classifier.classifyFull(new DenseVector(oneHot))));
    }
  }

}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/classifier/naivebayes/training/IndexInstancesMapperTest.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/classifier/naivebayes/training/IndexInstancesMapperTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/naivebayes/training/IndexInstancesMapperTest.java new file mode 100644 index 0000000..46d861c --- /dev/null +++ b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/naivebayes/training/IndexInstancesMapperTest.java @@ -0,0 +1,85 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */

package org.apache.mahout.classifier.naivebayes.training;

import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Counter;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.mahout.common.MahoutTestCase;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.VectorWritable;
import org.apache.mahout.math.map.OpenObjectIntHashMap;
import org.easymock.EasyMock;
import org.junit.Before;
import org.junit.Test;

/**
 * Tests {@link IndexInstancesMapper} against a mocked mapper context: an instance with a
 * known label is written under that label's index, and an instance with an unknown
 * label only increments the skipped-instances counter.
 */
public class IndexInstancesMapperTest extends MahoutTestCase {

  private Mapper.Context ctx;
  private OpenObjectIntHashMap<String> labelIndex;
  private VectorWritable instance;

  /** Creates the context mock, one sample instance, and a label index with "bird"=0, "cat"=1. */
  @Override
  @Before
  public void setUp() throws Exception {
    super.setUp();

    ctx = EasyMock.createMock(Mapper.Context.class);
    instance = new VectorWritable(new DenseVector(new double[] { 1, 0, 1, 1, 0 }));

    labelIndex = new OpenObjectIntHashMap<>();
    labelIndex.put("bird", 0);
    labelIndex.put("cat", 1);
  }

  /** Returns a mapper whose {@code labelIndex} field has been set to the test's label index. */
  private IndexInstancesMapper newMapper() throws Exception {
    IndexInstancesMapper mapper = new IndexInstancesMapper();
    setField(mapper, "labelIndex", labelIndex);
    return mapper;
  }

  @Test
  public void index() throws Exception {
    // Expectation: the instance is written under index 0, the index registered for "bird".
    ctx.write(new IntWritable(0), instance);
    EasyMock.replay(ctx);

    newMapper().map(new Text("/bird/"), instance, ctx);

    EasyMock.verify(ctx);
  }

  @Test
  public void skip() throws Exception {
    Counter skipped = EasyMock.createMock(Counter.class);

    // Expectation: nothing is written; the skipped counter is fetched and incremented by one.
    EasyMock.expect(ctx.getCounter(IndexInstancesMapper.Counter.SKIPPED_INSTANCES)).andReturn(skipped);
    skipped.increment(1);
    EasyMock.replay(ctx, skipped);

    // "fish" is not in the label index.
    newMapper().map(new Text("/fish/"), instance, ctx);

    EasyMock.verify(ctx, skipped);
  }

}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/classifier/naivebayes/training/ThetaMapperTest.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/classifier/naivebayes/training/ThetaMapperTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/naivebayes/training/ThetaMapperTest.java new file mode 100644 index 0000000..746ae0d --- /dev/null +++ b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/naivebayes/training/ThetaMapperTest.java @@ -0,0 +1,61 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.mahout.classifier.naivebayes.training; + +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.Text; +import org.apache.hadoop.mapreduce.Mapper; +import org.apache.mahout.common.MahoutTestCase; +import org.apache.mahout.math.DenseVector; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.VectorWritable; +import org.easymock.EasyMock; +import org.junit.Test; + +public class ThetaMapperTest extends MahoutTestCase { + + @Test + public void standard() throws Exception { + + Mapper.Context ctx = EasyMock.createMock(Mapper.Context.class); + ComplementaryThetaTrainer trainer = EasyMock.createMock(ComplementaryThetaTrainer.class); + + Vector instance1 = new DenseVector(new double[] { 1, 2, 3 }); + Vector instance2 = new DenseVector(new double[] { 4, 5, 6 }); + + Vector perLabelThetaNormalizer = new DenseVector(new double[] { 7, 8 }); + + ThetaMapper thetaMapper = new ThetaMapper(); + setField(thetaMapper, "trainer", trainer); + + trainer.train(0, instance1); + trainer.train(1, instance2); + EasyMock.expect(trainer.retrievePerLabelThetaNormalizer()).andReturn(perLabelThetaNormalizer); + ctx.write(new Text(TrainNaiveBayesJob.LABEL_THETA_NORMALIZER), new VectorWritable(perLabelThetaNormalizer)); + + EasyMock.replay(ctx, trainer); + + thetaMapper.map(new IntWritable(0), new VectorWritable(instance1), ctx); + thetaMapper.map(new IntWritable(1), new VectorWritable(instance2), ctx); + thetaMapper.cleanup(ctx); + + EasyMock.verify(ctx, trainer); + } + + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/classifier/naivebayes/training/WeightsMapperTest.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/classifier/naivebayes/training/WeightsMapperTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/naivebayes/training/WeightsMapperTest.java new 
file mode 100644 index 0000000..af0b464 --- /dev/null +++ b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/naivebayes/training/WeightsMapperTest.java @@ -0,0 +1,60 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.naivebayes.training; + +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.Text; +import org.apache.hadoop.mapreduce.Mapper; +import org.apache.mahout.common.MahoutTestCase; +import org.apache.mahout.math.DenseVector; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.VectorWritable; +import org.easymock.EasyMock; +import org.junit.Test; + +public class WeightsMapperTest extends MahoutTestCase { + + @Test + public void scores() throws Exception { + + Mapper.Context ctx = EasyMock.createMock(Mapper.Context.class); + Vector instance1 = new DenseVector(new double[] { 1, 0, 0.5, 0.5, 0 }); + Vector instance2 = new DenseVector(new double[] { 0, 0.5, 0, 0, 0 }); + Vector instance3 = new DenseVector(new double[] { 1, 0.5, 1, 1.5, 1 }); + + Vector weightsPerLabel = new DenseVector(new double[] { 0, 0 }); + + ctx.write(new Text(TrainNaiveBayesJob.WEIGHTS_PER_FEATURE), + new VectorWritable(new DenseVector(new double[] { 
2, 1, 1.5, 2, 1 }))); + ctx.write(new Text(TrainNaiveBayesJob.WEIGHTS_PER_LABEL), + new VectorWritable(new DenseVector(new double[] { 2.5, 5 }))); + + EasyMock.replay(ctx); + + WeightsMapper weights = new WeightsMapper(); + setField(weights, "weightsPerLabel", weightsPerLabel); + + weights.map(new IntWritable(0), new VectorWritable(instance1), ctx); + weights.map(new IntWritable(0), new VectorWritable(instance2), ctx); + weights.map(new IntWritable(1), new VectorWritable(instance3), ctx); + + weights.cleanup(ctx); + + EasyMock.verify(ctx); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMAlgorithmsTest.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMAlgorithmsTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMAlgorithmsTest.java new file mode 100644 index 0000000..ade25b8 --- /dev/null +++ b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMAlgorithmsTest.java @@ -0,0 +1,164 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.sequencelearning.hmm; + +import org.apache.mahout.math.Matrix; +import org.junit.Test; + +public class HMMAlgorithmsTest extends HMMTestBase { + + /** + * Test the forward algorithm by comparing the alpha values with the values + * obtained from HMM R model. We test the test observation sequence "O1" "O0" + * "O2" "O2" "O0" "O0" "O1" by comparing the generated alpha values to the + * R-generated "reference". + */ + @Test + public void testForwardAlgorithm() { + // intialize the expected alpha values + double[][] alphaExpectedA = { + {0.02, 0.0392, 0.002438, 0.00035456, 0.0011554672, 7.158497e-04, + 4.614927e-05}, + {0.01, 0.0054, 0.001824, 0.00069486, 0.0007586904, 2.514137e-04, + 1.721505e-05}, + {0.32, 0.0262, 0.002542, 0.00038026, 0.0001360234, 3.002345e-05, + 9.659608e-05}, + {0.03, 0.0000, 0.013428, 0.00951084, 0.0000000000, 0.000000e+00, + 2.428986e-05},}; + // fetch the alpha matrix using the forward algorithm + Matrix alpha = HmmAlgorithms.forwardAlgorithm(getModel(), getSequence(), false); + // first do some basic checking + assertNotNull(alpha); + assertEquals(4, alpha.numCols()); + assertEquals(7, alpha.numRows()); + // now compare the resulting matrices + for (int i = 0; i < 4; ++i) { + for (int j = 0; j < 7; ++j) { + assertEquals(alphaExpectedA[i][j], alpha.get(j, i), EPSILON); + } + } + } + + @Test + public void testLogScaledForwardAlgorithm() { + // intialize the expected alpha values + double[][] alphaExpectedA = { + {0.02, 0.0392, 0.002438, 0.00035456, 0.0011554672, 7.158497e-04, + 4.614927e-05}, + {0.01, 0.0054, 0.001824, 0.00069486, 0.0007586904, 2.514137e-04, + 1.721505e-05}, + {0.32, 0.0262, 0.002542, 0.00038026, 0.0001360234, 3.002345e-05, + 9.659608e-05}, + {0.03, 0.0000, 0.013428, 0.00951084, 0.0000000000, 0.000000e+00, + 2.428986e-05},}; + // fetch the alpha matrix using the forward 
algorithm + Matrix alpha = HmmAlgorithms.forwardAlgorithm(getModel(), getSequence(), true); + // first do some basic checking + assertNotNull(alpha); + assertEquals(4, alpha.numCols()); + assertEquals(7, alpha.numRows()); + // now compare the resulting matrices + for (int i = 0; i < 4; ++i) { + for (int j = 0; j < 7; ++j) { + assertEquals(Math.log(alphaExpectedA[i][j]), alpha.get(j, i), EPSILON); + } + } + } + + /** + * Test the backward algorithm by comparing the beta values with the values + * obtained from HMM R model. We test the following observation sequence "O1" + * "O0" "O2" "O2" "O0" "O0" "O1" by comparing the generated beta values to the + * R-generated "reference". + */ + @Test + public void testBackwardAlgorithm() { + // intialize the expected beta values + double[][] betaExpectedA = { + {0.0015730559, 0.003543656, 0.00738264, 0.040692, 0.0848, 0.17, 1}, + {0.0017191865, 0.002386795, 0.00923652, 0.052232, 0.1018, 0.17, 1}, + {0.0003825772, 0.001238558, 0.00259464, 0.012096, 0.0664, 0.66, 1}, + {0.0004390858, 0.007076994, 0.01063512, 0.013556, 0.0304, 0.17, 1}}; + // fetch the beta matrix using the backward algorithm + Matrix beta = HmmAlgorithms.backwardAlgorithm(getModel(), getSequence(), false); + // first do some basic checking + assertNotNull(beta); + assertEquals(4, beta.numCols()); + assertEquals(7, beta.numRows()); + // now compare the resulting matrices + for (int i = 0; i < 4; ++i) { + for (int j = 0; j < 7; ++j) { + assertEquals(betaExpectedA[i][j], beta.get(j, i), EPSILON); + } + } + } + + @Test + public void testLogScaledBackwardAlgorithm() { + // intialize the expected beta values + double[][] betaExpectedA = { + {0.0015730559, 0.003543656, 0.00738264, 0.040692, 0.0848, 0.17, 1}, + {0.0017191865, 0.002386795, 0.00923652, 0.052232, 0.1018, 0.17, 1}, + {0.0003825772, 0.001238558, 0.00259464, 0.012096, 0.0664, 0.66, 1}, + {0.0004390858, 0.007076994, 0.01063512, 0.013556, 0.0304, 0.17, 1}}; + // fetch the beta matrix using the backward 
algorithm + Matrix beta = HmmAlgorithms.backwardAlgorithm(getModel(), getSequence(), true); + // first do some basic checking + assertNotNull(beta); + assertEquals(4, beta.numCols()); + assertEquals(7, beta.numRows()); + // now compare the resulting matrices + for (int i = 0; i < 4; ++i) { + for (int j = 0; j < 7; ++j) { + assertEquals(Math.log(betaExpectedA[i][j]), beta.get(j, i), EPSILON); + } + } + } + + @Test + public void testViterbiAlgorithm() { + // initialize the expected hidden sequence + int[] expected = {2, 0, 3, 3, 0, 0, 2}; + // fetch the viterbi generated sequence + int[] computed = HmmAlgorithms.viterbiAlgorithm(getModel(), getSequence(), false); + // first make sure we return the correct size + assertNotNull(computed); + assertEquals(computed.length, getSequence().length); + // now check the contents + for (int i = 0; i < getSequence().length; ++i) { + assertEquals(expected[i], computed[i]); + } + } + + @Test + public void testLogScaledViterbiAlgorithm() { + // initialize the expected hidden sequence + int[] expected = {2, 0, 3, 3, 0, 0, 2}; + // fetch the viterbi generated sequence + int[] computed = HmmAlgorithms.viterbiAlgorithm(getModel(), getSequence(), true); + // first make sure we return the correct size + assertNotNull(computed); + assertEquals(computed.length, getSequence().length); + // now check the contents + for (int i = 0; i < getSequence().length; ++i) { + assertEquals(expected[i], computed[i]); + } + + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMEvaluatorTest.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMEvaluatorTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMEvaluatorTest.java new file mode 100644 index 0000000..3104cb1 --- /dev/null +++ 
b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMEvaluatorTest.java @@ -0,0 +1,63 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.sequencelearning.hmm; + +import org.apache.mahout.math.Matrix; +import org.junit.Test; + +public class HMMEvaluatorTest extends HMMTestBase { + + /** + * Test to make sure the computed model likelihood ist valid. 
Included tests + * are: a) forwad == backward likelihood b) model likelihood for test seqeunce + * is the expected one from R reference + */ + @Test + public void testModelLikelihood() { + // compute alpha and beta values + Matrix alpha = HmmAlgorithms.forwardAlgorithm(getModel(), getSequence(), false); + Matrix beta = HmmAlgorithms.backwardAlgorithm(getModel(), getSequence(), false); + // now test whether forward == backward likelihood + double forwardLikelihood = HmmEvaluator.modelLikelihood(alpha, false); + double backwardLikelihood = HmmEvaluator.modelLikelihood(getModel(), getSequence(), + beta, false); + assertEquals(forwardLikelihood, backwardLikelihood, EPSILON); + // also make sure that the likelihood matches the expected one + assertEquals(1.8425e-4, forwardLikelihood, EPSILON); + } + + /** + * Test to make sure the computed model likelihood ist valid. Included tests + * are: a) forwad == backward likelihood b) model likelihood for test seqeunce + * is the expected one from R reference + */ + @Test + public void testScaledModelLikelihood() { + // compute alpha and beta values + Matrix alpha = HmmAlgorithms.forwardAlgorithm(getModel(), getSequence(), true); + Matrix beta = HmmAlgorithms.backwardAlgorithm(getModel(), getSequence(), true); + // now test whether forward == backward likelihood + double forwardLikelihood = HmmEvaluator.modelLikelihood(alpha, true); + double backwardLikelihood = HmmEvaluator.modelLikelihood(getModel(), getSequence(), + beta, true); + assertEquals(forwardLikelihood, backwardLikelihood, EPSILON); + // also make sure that the likelihood matches the expected one + assertEquals(1.8425e-4, forwardLikelihood, EPSILON); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMModelTest.java ---------------------------------------------------------------------- diff --git 
a/community/mahout-mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMModelTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMModelTest.java new file mode 100644 index 0000000..3260f51 --- /dev/null +++ b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMModelTest.java @@ -0,0 +1,32 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.classifier.sequencelearning.hmm; + +import org.junit.Test; + +public class HMMModelTest extends HMMTestBase { + + @Test + public void testRandomModelGeneration() { + // make sure we generate a valid random model + HmmModel model = new HmmModel(10, 20); + // check whether the model is valid + HmmUtils.validate(model); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMTestBase.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMTestBase.java b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMTestBase.java new file mode 100644 index 0000000..90f1cd8 --- /dev/null +++ b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMTestBase.java @@ -0,0 +1,73 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.classifier.sequencelearning.hmm; + +import org.apache.mahout.common.MahoutTestCase; +import org.apache.mahout.math.DenseMatrix; +import org.apache.mahout.math.DenseVector; + +public class HMMTestBase extends MahoutTestCase { + + private HmmModel model; + private final int[] sequence = {1, 0, 2, 2, 0, 0, 1}; + + /** + * We initialize a new HMM model using the following parameters # hidden + * states: 4 ("H0","H1","H2","H3") # output states: 3 ("O0","O1","O2") # + * transition matrix to: H0 H1 H2 H3 from: H0 0.5 0.1 0.1 0.3 H1 0.4 0.4 0.1 + * 0.1 H2 0.1 0.0 0.8 0.1 H3 0.1 0.1 0.1 0.7 # output matrix to: O0 O1 O2 + * from: H0 0.8 0.1 0.1 H1 0.6 0.1 0.3 H2 0.1 0.8 0.1 H3 0.0 0.1 0.9 # initial + * probabilities H0 0.2 + * <p/> + * H1 0.1 H2 0.4 H3 0.3 + * <p/> + * We also intialize an observation sequence: "O1" "O0" "O2" "O2" "O0" "O0" + * "O1" + */ + + @Override + public void setUp() throws Exception { + super.setUp(); + // intialize the hidden/output state names + String[] hiddenNames = {"H0", "H1", "H2", "H3"}; + String[] outputNames = {"O0", "O1", "O2"}; + // initialize the transition matrix + double[][] transitionP = {{0.5, 0.1, 0.1, 0.3}, {0.4, 0.4, 0.1, 0.1}, + {0.1, 0.0, 0.8, 0.1}, {0.1, 0.1, 0.1, 0.7}}; + // initialize the emission matrix + double[][] emissionP = {{0.8, 0.1, 0.1}, {0.6, 0.1, 0.3}, + {0.1, 0.8, 0.1}, {0.0, 0.1, 0.9}}; + // initialize the initial probability vector + double[] initialP = {0.2, 0.1, 0.4, 0.3}; + // now generate the model + model = new HmmModel(new DenseMatrix(transitionP), new DenseMatrix( + emissionP), new DenseVector(initialP)); + model.registerHiddenStateNames(hiddenNames); + model.registerOutputStateNames(outputNames); + // make sure the model is valid :) + HmmUtils.validate(model); + } + + protected HmmModel getModel() { + return model; + } + + protected int[] getSequence() { + return sequence; + } +} 
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMTrainerTest.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMTrainerTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMTrainerTest.java new file mode 100644 index 0000000..b8f3186 --- /dev/null +++ b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMTrainerTest.java @@ -0,0 +1,163 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.classifier.sequencelearning.hmm; + +import org.apache.mahout.math.Matrix; +import org.apache.mahout.math.Vector; +import org.junit.Test; + +public class HMMTrainerTest extends HMMTestBase { + + @Test + public void testViterbiTraining() { + // initialize the expected model parameters (from R) + // expected transition matrix + double[][] transitionE = {{0.3125, 0.0625, 0.3125, 0.3125}, + {0.25, 0.25, 0.25, 0.25}, {0.5, 0.071429, 0.357143, 0.071429}, + {0.5, 0.1, 0.1, 0.3}}; + // initialize the emission matrix + double[][] emissionE = {{0.882353, 0.058824, 0.058824}, + {0.333333, 0.333333, 0.3333333}, {0.076923, 0.846154, 0.076923}, + {0.111111, 0.111111, 0.777778}}; + + // train the given network to the following output sequence + int[] observed = {1, 0, 2, 2, 0, 0, 1, 1, 1, 0, 2, 0, 1, 0, 0}; + + HmmModel trained = HmmTrainer.trainViterbi(getModel(), observed, 0.5, 0.1, 10, false); + + // now check whether the model matches our expectations + Matrix emissionMatrix = trained.getEmissionMatrix(); + Matrix transitionMatrix = trained.getTransitionMatrix(); + + for (int i = 0; i < trained.getNrOfHiddenStates(); ++i) { + for (int j = 0; j < trained.getNrOfHiddenStates(); ++j) { + assertEquals(transitionMatrix.getQuick(i, j), transitionE[i][j], EPSILON); + } + + for (int j = 0; j < trained.getNrOfOutputStates(); ++j) { + assertEquals(emissionMatrix.getQuick(i, j), emissionE[i][j], EPSILON); + } + } + + } + + @Test + public void testScaledViterbiTraining() { + // initialize the expected model parameters (from R) + // expected transition matrix + double[][] transitionE = {{0.3125, 0.0625, 0.3125, 0.3125}, + {0.25, 0.25, 0.25, 0.25}, {0.5, 0.071429, 0.357143, 0.071429}, + {0.5, 0.1, 0.1, 0.3}}; + // initialize the emission matrix + double[][] emissionE = {{0.882353, 0.058824, 0.058824}, + {0.333333, 0.333333, 0.3333333}, {0.076923, 0.846154, 0.076923}, + {0.111111, 0.111111, 0.777778}}; + + // train the given network to the following output 
sequence + int[] observed = {1, 0, 2, 2, 0, 0, 1, 1, 1, 0, 2, 0, 1, 0, 0}; + + HmmModel trained = HmmTrainer.trainViterbi(getModel(), observed, 0.5, 0.1, 10, + true); + + // now check whether the model matches our expectations + Matrix emissionMatrix = trained.getEmissionMatrix(); + Matrix transitionMatrix = trained.getTransitionMatrix(); + + for (int i = 0; i < trained.getNrOfHiddenStates(); ++i) { + for (int j = 0; j < trained.getNrOfHiddenStates(); ++j) { + assertEquals(transitionMatrix.getQuick(i, j), transitionE[i][j], + EPSILON); + } + + for (int j = 0; j < trained.getNrOfOutputStates(); ++j) { + assertEquals(emissionMatrix.getQuick(i, j), emissionE[i][j], + EPSILON); + } + } + + } + + @Test + public void testBaumWelchTraining() { + // train the given network to the following output sequence + int[] observed = {1, 0, 2, 2, 0, 0, 1, 1, 1, 0, 2, 0, 1, 0, 0}; + + // expected values from Matlab HMM package / R HMM package + double[] initialExpected = {0, 0, 1.0, 0}; + double[][] transitionExpected = {{0.2319, 0.0993, 0.0005, 0.6683}, + {0.0001, 0.3345, 0.6654, 0}, {0.5975, 0, 0.4025, 0}, + {0.0024, 0.6657, 0, 0.3319}}; + double[][] emissionExpected = {{0.9995, 0.0004, 0.0001}, + {0.9943, 0.0036, 0.0021}, {0.0059, 0.9941, 0}, {0, 0, 1}}; + + HmmModel trained = HmmTrainer.trainBaumWelch(getModel(), observed, 0.1, 10, + false); + + Vector initialProbabilities = trained.getInitialProbabilities(); + Matrix emissionMatrix = trained.getEmissionMatrix(); + Matrix transitionMatrix = trained.getTransitionMatrix(); + + for (int i = 0; i < trained.getNrOfHiddenStates(); ++i) { + assertEquals(initialProbabilities.get(i), initialExpected[i], + 0.0001); + for (int j = 0; j < trained.getNrOfHiddenStates(); ++j) { + assertEquals(transitionMatrix.getQuick(i, j), + transitionExpected[i][j], 0.0001); + } + for (int j = 0; j < trained.getNrOfOutputStates(); ++j) { + assertEquals(emissionMatrix.getQuick(i, j), + emissionExpected[i][j], 0.0001); + } + } + } + + @Test + public void 
testScaledBaumWelchTraining() { + // train the given network to the following output sequence + int[] observed = {1, 0, 2, 2, 0, 0, 1, 1, 1, 0, 2, 0, 1, 0, 0}; + + // expected values from Matlab HMM package / R HMM package + double[] initialExpected = {0, 0, 1.0, 0}; + double[][] transitionExpected = {{0.2319, 0.0993, 0.0005, 0.6683}, + {0.0001, 0.3345, 0.6654, 0}, {0.5975, 0, 0.4025, 0}, + {0.0024, 0.6657, 0, 0.3319}}; + double[][] emissionExpected = {{0.9995, 0.0004, 0.0001}, + {0.9943, 0.0036, 0.0021}, {0.0059, 0.9941, 0}, {0, 0, 1}}; + + HmmModel trained = HmmTrainer + .trainBaumWelch(getModel(), observed, 0.1, 10, true); + + Vector initialProbabilities = trained.getInitialProbabilities(); + Matrix emissionMatrix = trained.getEmissionMatrix(); + Matrix transitionMatrix = trained.getTransitionMatrix(); + + for (int i = 0; i < trained.getNrOfHiddenStates(); ++i) { + assertEquals(initialProbabilities.get(i), initialExpected[i], + 0.0001); + for (int j = 0; j < trained.getNrOfHiddenStates(); ++j) { + assertEquals(transitionMatrix.getQuick(i, j), + transitionExpected[i][j], 0.0001); + } + for (int j = 0; j < trained.getNrOfOutputStates(); ++j) { + assertEquals(emissionMatrix.getQuick(i, j), + emissionExpected[i][j], 0.0001); + } + } + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMUtilsTest.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMUtilsTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMUtilsTest.java new file mode 100644 index 0000000..6c34718 --- /dev/null +++ b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMUtilsTest.java @@ -0,0 +1,161 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license 
agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.sequencelearning.hmm; + +import java.util.Arrays; +import java.util.List; + +import org.apache.mahout.math.DenseMatrix; +import org.apache.mahout.math.DenseVector; +import org.apache.mahout.math.Matrix; +import org.apache.mahout.math.Vector; +import org.junit.Test; + +public class HMMUtilsTest extends HMMTestBase { + + private Matrix legal22; + private Matrix legal23; + private Matrix legal33; + private Vector legal2; + private Matrix illegal22; + + @Override + public void setUp() throws Exception { + super.setUp(); + legal22 = new DenseMatrix(new double[][]{{0.5, 0.5}, {0.3, 0.7}}); + legal23 = new DenseMatrix(new double[][]{{0.2, 0.2, 0.6}, + {0.3, 0.3, 0.4}}); + legal33 = new DenseMatrix(new double[][]{{0.1, 0.1, 0.8}, + {0.1, 0.2, 0.7}, {0.2, 0.3, 0.5}}); + legal2 = new DenseVector(new double[]{0.4, 0.6}); + illegal22 = new DenseMatrix(new double[][]{{1, 2}, {3, 4}}); + } + + @Test + public void testValidatorLegal() { + HmmUtils.validate(new HmmModel(legal22, legal23, legal2)); + } + + @Test + public void testValidatorDimensionError() { + try { + HmmUtils.validate(new HmmModel(legal33, legal23, legal2)); + } catch (IllegalArgumentException e) { + // success + return; + } + fail(); + } + + @Test + public void 
testValidatorIllegelMatrixError() {
+    try {
+      HmmUtils.validate(new HmmModel(illegal22, legal23, legal2));
+    } catch (IllegalArgumentException e) {
+      // success
+      return;
+    }
+    fail();
+  }
+
+  /**
+   * Encodes a sequence of hidden-state names and a sequence of output-state names
+   * and compares the resulting ids with the expected ones. The names "H4" and "O4"
+   * are expected to be unknown to the model and therefore to be encoded as the
+   * default value {@code -1} passed to the encoder.
+   */
+  @Test
+  public void testEncodeStateSequence() {
+    String[] hiddenSequence = {"H1", "H2", "H0", "H3", "H4"};
+    String[] outputSequence = {"O1", "O2", "O4", "O0"};
+    // encode both sequences (flag false = hidden names, true = output names)
+    int[] hiddenSequenceEnc = HmmUtils.encodeStateSequence(getModel(),
+        Arrays.asList(hiddenSequence), false, -1);
+    int[] outputSequenceEnc = HmmUtils.encodeStateSequence(getModel(),
+        Arrays.asList(outputSequence), true, -1);
+    // expected state sequences
+    int[] hiddenSequenceExp = {1, 2, 0, 3, -1};
+    int[] outputSequenceExp = {1, 2, -1, 0};
+    // check the lengths first: without this a truncated result would pass the
+    // element-wise loops below and an over-long one would fail with an
+    // ArrayIndexOutOfBoundsException instead of an assertion failure
+    assertEquals(hiddenSequenceExp.length, hiddenSequenceEnc.length);
+    assertEquals(outputSequenceExp.length, outputSequenceEnc.length);
+    // compare element by element
+    for (int i = 0; i < hiddenSequenceEnc.length; ++i) {
+      assertEquals(hiddenSequenceExp[i], hiddenSequenceEnc[i]);
+    }
+    for (int i = 0; i < outputSequenceEnc.length; ++i) {
+      assertEquals(outputSequenceExp[i], outputSequenceEnc[i]);
+    }
+  }
+
+  /**
+   * Decodes a sequence of hidden-state ids and a sequence of output-state ids
+   * and compares the resulting names with the expected ones. The id {@code 10}
+   * is expected to be unknown to the model and therefore to be decoded as the
+   * default name "unknown" passed to the decoder.
+   */
+  @Test
+  public void testDecodeStateSequence() {
+    int[] hiddenSequence = {1, 2, 0, 3, 10};
+    int[] outputSequence = {1, 2, 10, 0};
+    // decode both sequences (flag false = hidden ids, true = output ids)
+    List<String> hiddenSequenceDec = HmmUtils.decodeStateSequence(
+        getModel(), hiddenSequence, false, "unknown");
+    List<String> outputSequenceDec = HmmUtils.decodeStateSequence(
+        getModel(), outputSequence, true, "unknown");
+    // expected state sequences
+    String[] hiddenSequenceExp = {"H1", "H2", "H0", "H3", "unknown"};
+    String[] outputSequenceExp = {"O1", "O2", "unknown", "O0"};
+    // check the sizes first so that extra or missing decoded entries are detected
+    assertEquals(hiddenSequenceExp.length, hiddenSequenceDec.size());
+    assertEquals(outputSequenceExp.length, outputSequenceDec.size());
+    // compare element by element
+    for (int i = 0; i < hiddenSequenceExp.length; ++i) {
+      assertEquals(hiddenSequenceExp[i], hiddenSequenceDec.get(i));
+    }
+    for (int i = 0; i < outputSequenceExp.length; ++i) {
+      assertEquals(outputSequenceExp[i], outputSequenceDec.get(i));
+    }
+  }
+
+  /**
+   * Builds a model from unnormalized values, normalizes it and checks that
+   * {@code HmmUtils.validate} accepts the result (i.e. does not throw).
+   */
+  @Test
+  public void testNormalizeModel() {
+    DenseVector ip = new DenseVector(new double[]{10, 20});
+    DenseMatrix tr = new DenseMatrix(new double[][]{{10, 10}, {20, 25}});
+    DenseMatrix em = new DenseMatrix(new double[][]{{5, 7}, {10, 15}});
+    HmmModel model = new HmmModel(tr, em, ip);
+    HmmUtils.normalizeModel(model);
+    // the model should be valid now
+    HmmUtils.validate(model);
+  }
+
+  /**
+   * Truncates a model whose parameters are 0.9998 in one position and 0.0001
+   * elsewhere, using the threshold 0.01. The truncated model must be valid, and
+   * (within {@code EPSILON}) every 0.0001 entry is expected to become 0.0 while
+   * every 0.9998 entry is expected to become 1.0.
+   */
+  @Test
+  public void testTruncateModel() {
+    DenseVector ip = new DenseVector(new double[]{0.0001, 0.0001, 0.9998});
+    DenseMatrix tr = new DenseMatrix(new double[][]{
+        {0.9998, 0.0001, 0.0001}, {0.0001, 0.9998, 0.0001},
+        {0.0001, 0.0001, 0.9998}});
+    DenseMatrix em = new DenseMatrix(new double[][]{
+        {0.9998, 0.0001, 0.0001}, {0.0001, 0.9998, 0.0001},
+        {0.0001, 0.0001, 0.9998}});
+    HmmModel model = new HmmModel(tr, em, ip);
+    // now truncate the model
+    HmmModel sparseModel = HmmUtils.truncateModel(model, 0.01);
+    // first make sure this is a valid model
+    HmmUtils.validate(sparseModel);
+    // now check whether the values are as expected
+    Vector sparseIp = sparseModel.getInitialProbabilities();
+    Matrix sparseTr = sparseModel.getTransitionMatrix();
+    Matrix sparseEm = sparseModel.getEmissionMatrix();
+    for (int i = 0; i < sparseModel.getNrOfHiddenStates(); ++i) {
+      // only the third initial probability was above the threshold
+      assertEquals(i == 2 ? 1.0 : 0.0, sparseIp.getQuick(i), EPSILON);
+      for (int j = 0; j < sparseModel.getNrOfHiddenStates(); ++j) {
+        // both matrices are expected to end up as the identity
+        double expected = i == j ? 1.0 : 0.0;
+        assertEquals(expected, sparseTr.getQuick(i, j), EPSILON);
+        assertEquals(expected, sparseEm.getQuick(i, j), EPSILON);
+      }
+    }
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegressionTest.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegressionTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegressionTest.java
new file mode 100644
index 0000000..7ea8cb2
--- /dev/null
+++ b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegressionTest.java
@@ -0,0 +1,186 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.jet.random.Exponential;
+import org.junit.Test;
+
+import com.carrotsearch.randomizedtesting.annotations.ThreadLeakLingering;
+
+import java.util.Random;
+
+/**
+ * Tests for {@link AdaptiveLogisticRegression} and its {@code Wrapper}: training
+ * quality measured by AUC on synthetic data, independence of a copied wrapper
+ * from its original, and the step-size schedule.
+ */
+public final class AdaptiveLogisticRegressionTest extends MahoutTestCase {
+
+  /**
+   * Trains a single wrapped learner and then a full adaptive learner on
+   * synthetic data and expects both to reach an AUC within 0.1 of 1.
+   */
+  @ThreadLeakLingering(linger = 1000)
+  @Test
+  public void testTrain() {
+
+    Random gen = RandomUtils.getRandom();
+    Vector beta = randomBeta(gen);
+
+    AdaptiveLogisticRegression.Wrapper cl = new AdaptiveLogisticRegression.Wrapper(2, 200, new L1());
+    cl.update(new double[]{1.0e-5, 1});
+
+    for (int i = 0; i < 10000; i++) {
+      AdaptiveLogisticRegression.TrainingExample r = getExample(i, gen, beta);
+      cl.train(r);
+      if (i % 1000 == 0) {
+        System.out.printf("%10d %10.3f\n", i, cl.getLearner().auc());
+      }
+    }
+    assertEquals(1, cl.getLearner().auc(), 0.1);
+
+    AdaptiveLogisticRegression adaptiveLogisticRegression = new AdaptiveLogisticRegression(2, 200, new L1());
+    // close in a finally block so the learner is released even when an assertion
+    // or the training itself fails (the test is annotated for lingering threads,
+    // so close() presumably stops background workers -- confirm)
+    try {
+      adaptiveLogisticRegression.setInterval(1000);
+
+      for (int i = 0; i < 20000; i++) {
+        AdaptiveLogisticRegression.TrainingExample r = getExample(i, gen, beta);
+        adaptiveLogisticRegression.train(r.getKey(), r.getActual(), r.getInstance());
+        if (i % 1000 == 0 && adaptiveLogisticRegression.getBest() != null) {
+          System.out.printf("%10d %10.4f %10.8f %.3f\n",
+                            i, adaptiveLogisticRegression.auc(),
+                            Math.log10(adaptiveLogisticRegression.getBest().getMappedParams()[0]), adaptiveLogisticRegression.getBest().getMappedParams()[1]);
+        }
+      }
+      assertEquals(1, adaptiveLogisticRegression.auc(), 0.1);
+    } finally {
+      adaptiveLogisticRegression.close();
+    }
+  }
+
+  /**
+   * Creates the 200-element coefficient vector shared by the training tests.
+   * For every element one uniform draw from {@code gen} picks the sign
+   * (negative when the draw is below 0.5) and one draw from an
+   * {@code Exponential(0.5, gen)} sampler supplies the magnitude.
+   *
+   * @param gen random source; it is consumed in the same order as the inline
+   *            loops this helper replaces
+   * @return the freshly filled coefficient vector
+   */
+  private static Vector randomBeta(Random gen) {
+    Exponential exp = new Exponential(0.5, gen);
+    Vector beta = new DenseVector(200);
+    for (Vector.Element element : beta.all()) {
+      int sign = 1;
+      if (gen.nextDouble() < 0.5) {
+        sign = -1;
+      }
+      element.set(sign * exp.nextDouble());
+    }
+    return beta;
+  }
+
+  /**
+   * Generates one synthetic training example: a 200-element 0/1 feature vector
+   * (each element is 1 when a uniform draw is below 0.3) and a binary target
+   * that is 1 with probability {@code 1 / (1 + exp(1.5 - data.dot(beta)))}.
+   *
+   * @param i    key stored in the returned example
+   * @param gen  random source for features and target
+   * @param beta coefficient vector of the generating logistic model
+   * @return the example, created with a {@code null} second constructor argument
+   */
+  private static AdaptiveLogisticRegression.TrainingExample getExample(int i, Random gen, Vector beta) {
+    Vector data = new DenseVector(200);
+
+    for (Vector.Element element : data.all()) {
+      element.set(gen.nextDouble() < 0.3 ? 1 : 0);
+    }
+
+    double p = 1 / (1 + Math.exp(1.5 - data.dot(beta)));
+    int target = 0;
+    if (gen.nextDouble() < p) {
+      target = 1;
+    }
+    return new AdaptiveLogisticRegression.TrainingExample(i, null, target, data);
+  }
+
+  /**
+   * Trains a wrapper, copies it, keeps training only the copy and checks that
+   * the copy starts with an untrained AUC of 0.5, quickly overtakes the
+   * original's AUC, and that the original's AUC is left untouched.
+   */
+  @Test
+  public void copyLearnsAsExpected() {
+    Random gen = RandomUtils.getRandom();
+    Vector beta = randomBeta(gen);
+
+    // train one copy of a wrapped learner
+    AdaptiveLogisticRegression.Wrapper w = new AdaptiveLogisticRegression.Wrapper(2, 200, new L1());
+    for (int i = 0; i < 3000; i++) {
+      AdaptiveLogisticRegression.TrainingExample r = getExample(i, gen, beta);
+      w.train(r);
+      if (i % 1000 == 0) {
+        System.out.printf("%10d %.3f\n", i, w.getLearner().auc());
+      }
+    }
+    System.out.printf("%10d %.3f\n", 3000, w.getLearner().auc());
+    double auc1 = w.getLearner().auc();
+
+    // then switch to a copy of that learner ... progress should continue
+    AdaptiveLogisticRegression.Wrapper w2 = w.copy();
+
+    for (int i = 0; i < 5000; i++) {
+      if (i % 1000 == 0) {
+        if (i == 0) {
+          assertEquals("Should have started with no data", 0.5, w2.getLearner().auc(), 0.0001);
+        }
+        if (i == 1000) {
+          double auc2 = w2.getLearner().auc();
+          assertTrue("Should have had head-start", Math.abs(auc2 - 0.5) > 0.1);
+          assertTrue("AUC should improve quickly on copy", auc1 < auc2);
+        }
+        System.out.printf("%10d %.3f\n", i, w2.getLearner().auc());
+      }
+      AdaptiveLogisticRegression.TrainingExample r = getExample(i, gen, beta);
+      w2.train(r);
+    }
+    assertEquals("Original should not change after copy is updated", auc1, w.getLearner().auc(), 1.0e-5);
+
+    // this improvement is really quite lenient
+    assertTrue("AUC should improve significantly on copy", auc1 < w2.getLearner().auc() - 0.05);
+
+    // the original wrapper's AUC must be exactly unchanged (zero tolerance)
+    assertEquals(auc1, w.getLearner().auc(), 0);
+  }
+
+  /** Checks {@code AdaptiveLogisticRegression.stepSize} against fixed expected values. */
+  @Test
+  public void stepSize() {
+    assertEquals(500, AdaptiveLogisticRegression.stepSize(15000, 2));
+    assertEquals(2000, AdaptiveLogisticRegression.stepSize(15000, 2.6));
+    assertEquals(5000, AdaptiveLogisticRegression.stepSize(24000, 2.6));
+    assertEquals(10000, AdaptiveLogisticRegression.stepSize(15000, 3));
+  }
+
+  /**
+   * With a single interval of 5000, every position from 15000 up to 19999 is
+   * expected to have 20000 as its next step.
+   */
+  @Test
+  @ThreadLeakLingering(linger = 1000)
+  public void constantStep() {
+    AdaptiveLogisticRegression lr = new AdaptiveLogisticRegression(2, 1000, new L1());
+    // finally guarantees close() even if one of the assertions fails
+    try {
+      lr.setInterval(5000);
+      assertEquals(20000, lr.nextStep(15000));
+      assertEquals(20000, lr.nextStep(15001));
+      assertEquals(20000, lr.nextStep(16500));
+      assertEquals(20000, lr.nextStep(19999));
+    } finally {
+      lr.close();
+    }
+  }
+
+
+  /**
+   * With {@code setInterval(2000, 10000)} the step is expected to be 2000 below
+   * 20000, 5000 from 20000 up to 50000, and 10000 from there on.
+   */
+  @Test
+  @ThreadLeakLingering(linger = 1000)
+  public void growingStep() {
+    AdaptiveLogisticRegression lr = new AdaptiveLogisticRegression(2, 1000, new L1());
+    // finally guarantees close() even if one of the assertions fails
+    try {
+      lr.setInterval(2000, 10000);
+
+      // start with minimum step size
+      for (int i = 2000; i < 20000; i += 2000) {
+        assertEquals(i + 2000, lr.nextStep(i));
+      }
+
+      // then level up a bit
+      for (int i = 20000; i < 50000; i += 5000) {
+        assertEquals(i + 5000, lr.nextStep(i));
+      }
+
+      // and more, but we top out with this step size
+      for (int i = 50000; i < 500000; i += 10000) {
+        assertEquals(i + 10000, lr.nextStep(i));
+      }
+    } finally {
+      lr.close();
+    }
+  }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/classifier/sgd/CsvRecordFactoryTest.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/classifier/sgd/CsvRecordFactoryTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/sgd/CsvRecordFactoryTest.java
new file mode 100644
index 0000000..6ee0ddf
--- /dev/null
+++ b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/sgd/CsvRecordFactoryTest.java
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import com.google.common.collect.ImmutableMap;
+import org.apache.mahout.common.MahoutTestCase;
+import org.apache.mahout.math.DenseVector;
+import org.apache.mahout.math.Vector;
+import org.apache.mahout.vectorizer.encoders.Dictionary;
+import org.junit.Test;
+
+/**
+ * Tests for {@link CsvRecordFactory} line encoding and for the value ordering
+ * kept by {@link Dictionary}.
+ */
+public final class CsvRecordFactoryTest extends MahoutTestCase {
+
+  /**
+   * Encodes three CSV lines into the same vector (cleared in between) and checks
+   * the returned target code and the encoded values for each line. The target
+   * strings "yes", "no" and "invalid" are expected to yield 0, 1 and 1.
+   */
+  @Test
+  public void testAddToVector() {
+    RecordFactory factory = new CsvRecordFactory("y", ImmutableMap.of("x1", "n", "x2", "w", "x3", "t"));
+    factory.firstLine("z,x1,y,x2,x3,q");
+    factory.maxTargetValue(2);
+
+    Vector encoded = new DenseVector(2000);
+
+    int target = factory.processLine("ignore,3.1,yes,tiger, \"this is text\",ignore", encoded);
+    assertEquals(0, target);
+    // apart from the numeric 3.1 every encoded value is expected to be exactly 1
+    assertEncoding(encoded, 3.1, 8.0, 1.0, 0);
+
+    encoded.assign(0);
+    target = factory.processLine("ignore,5.3,no,line, \"and more text and more\",ignore", encoded);
+    assertEquals(1, target);
+    // apart from the numeric 5.3 the values are no longer all 1 for this line
+    assertEncoding(encoded, 5.3, 10.339850002884626, 1.5849625007211563, 1.0e-6);
+
+    encoded.assign(0);
+    target = factory.processLine("ignore,5.3,invalid,line, \"and more text and more\",ignore", encoded);
+    assertEquals(1, target);
+    // same features as the previous line, so the same encoding is expected
+    assertEncoding(encoded, 5.3, 10.339850002884626, 1.5849625007211563, 1.0e-6);
+  }
+
+  /**
+   * Checks one encoded line: 9 values must be set and the largest must be the
+   * numeric field; after that value is zeroed in place, 8 values must remain
+   * with the given 1-norm and maximum.
+   *
+   * @param encoded      vector filled by {@code processLine}; its largest entry is set to 0
+   * @param numericValue expected largest entry before zeroing (compared exactly)
+   * @param remainingL1  expected 1-norm of the remaining entries
+   * @param remainingMax expected largest remaining entry
+   * @param tolerance    delta used for {@code remainingL1} and {@code remainingMax}
+   */
+  private void assertEncoding(Vector encoded, double numericValue, double remainingL1,
+                              double remainingMax, double tolerance) {
+    // should have 9 values set
+    assertEquals(9.0, encoded.norm(0), 0);
+    assertEquals(numericValue, encoded.maxValue(), 0);
+    // drop the numeric value so the remaining entries can be checked on their own
+    encoded.set(encoded.maxValueIndex(), 0);
+    assertEquals(8.0, encoded.norm(0), 0);
+    assertEquals(remainingL1, encoded.norm(1), tolerance);
+    assertEquals(remainingMax, encoded.maxValue(), tolerance);
+  }
+
+  /**
+   * Interned values must be reported in insertion order, both by the original
+   * dictionary and by one rebuilt from its value list.
+   */
+  @Test
+  public void testDictionaryOrder() {
+    String expectedOrder = "[a, d, c, b, qrz]";
+
+    Dictionary original = new Dictionary();
+    for (String term : new String[]{"a", "d", "c", "b", "qrz"}) {
+      original.intern(term);
+    }
+    assertEquals(expectedOrder, original.values().toString());
+
+    Dictionary rebuilt = Dictionary.fromList(original.values());
+    assertEquals(expectedOrder, rebuilt.values().toString());
+
+  }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/classifier/sgd/GradientMachineTest.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/classifier/sgd/GradientMachineTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/sgd/GradientMachineTest.java
new file mode 100644
index 0000000..06a876e
--- /dev/null
+++ b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/sgd/GradientMachineTest.java
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.classifier.sgd;
+
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.math.Vector;
+import org.junit.Test;
+
+import java.io.IOException;
+import java.util.Random;
+
+/**
+ * Trains a {@link GradientMachine} on the standard data provided by
+ * {@link OnlineBaseTest} and runs the base class check on the result.
+ */
+public final class GradientMachineTest extends OnlineBaseTest {
+
+  /**
+   * Builds a gradient machine with randomly initialised weights, trains it on
+   * the standard input and evaluates it with the base class {@code test} helper.
+   *
+   * @throws IOException declared because reading the standard data may fail
+   */
+  @Test
+  public void testGradientmachine() throws IOException {
+    Vector expected = readStandardData();
+
+    // NOTE(review): 8, 4, 2 are presumably feature / hidden / output sizes -- confirm
+    GradientMachine machine =
+        new GradientMachine(8, 4, 2)
+            .learningRate(0.1)
+            .regularization(0.01);
+    Random rng = RandomUtils.getRandom();
+    machine.initWeights(rng);
+
+    train(getInput(), expected, machine);
+
+    // TODO not sure why the RNG change made this fail. Value is 0.5-1.0 no matter what seed is chosen?
+    // Until that is understood the stricter check stays disabled:
+    //   test(getInput(), expected, machine, 0.05, 1);
+    test(getInput(), expected, machine, 1.0, 1);
+  }
+
+}
