http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMEvaluatorTest.java ---------------------------------------------------------------------- diff --git a/mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMEvaluatorTest.java b/mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMEvaluatorTest.java new file mode 100644 index 0000000..3104cb1 --- /dev/null +++ b/mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMEvaluatorTest.java @@ -0,0 +1,63 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.sequencelearning.hmm; + +import org.apache.mahout.math.Matrix; +import org.junit.Test; + +public class HMMEvaluatorTest extends HMMTestBase { + + /** + * Test to make sure the computed model likelihood ist valid. Included tests + * are: a) forwad == backward likelihood b) model likelihood for test seqeunce + * is the expected one from R reference + */ + @Test + public void testModelLikelihood() { + // compute alpha and beta values + Matrix alpha = HmmAlgorithms.forwardAlgorithm(getModel(), getSequence(), false); + Matrix beta = HmmAlgorithms.backwardAlgorithm(getModel(), getSequence(), false); + // now test whether forward == backward likelihood + double forwardLikelihood = HmmEvaluator.modelLikelihood(alpha, false); + double backwardLikelihood = HmmEvaluator.modelLikelihood(getModel(), getSequence(), + beta, false); + assertEquals(forwardLikelihood, backwardLikelihood, EPSILON); + // also make sure that the likelihood matches the expected one + assertEquals(1.8425e-4, forwardLikelihood, EPSILON); + } + + /** + * Test to make sure the computed model likelihood ist valid. Included tests + * are: a) forwad == backward likelihood b) model likelihood for test seqeunce + * is the expected one from R reference + */ + @Test + public void testScaledModelLikelihood() { + // compute alpha and beta values + Matrix alpha = HmmAlgorithms.forwardAlgorithm(getModel(), getSequence(), true); + Matrix beta = HmmAlgorithms.backwardAlgorithm(getModel(), getSequence(), true); + // now test whether forward == backward likelihood + double forwardLikelihood = HmmEvaluator.modelLikelihood(alpha, true); + double backwardLikelihood = HmmEvaluator.modelLikelihood(getModel(), getSequence(), + beta, true); + assertEquals(forwardLikelihood, backwardLikelihood, EPSILON); + // also make sure that the likelihood matches the expected one + assertEquals(1.8425e-4, forwardLikelihood, EPSILON); + } + +}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMModelTest.java ---------------------------------------------------------------------- diff --git a/mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMModelTest.java b/mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMModelTest.java new file mode 100644 index 0000000..3260f51 --- /dev/null +++ b/mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMModelTest.java @@ -0,0 +1,32 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.sequencelearning.hmm; + +import org.junit.Test; + +public class HMMModelTest extends HMMTestBase { + + @Test + public void testRandomModelGeneration() { + // make sure we generate a valid random model + HmmModel model = new HmmModel(10, 20); + // check whether the model is valid + HmmUtils.validate(model); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMTestBase.java ---------------------------------------------------------------------- diff --git a/mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMTestBase.java b/mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMTestBase.java new file mode 100644 index 0000000..90f1cd8 --- /dev/null +++ b/mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMTestBase.java @@ -0,0 +1,73 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.sequencelearning.hmm; + +import org.apache.mahout.common.MahoutTestCase; +import org.apache.mahout.math.DenseMatrix; +import org.apache.mahout.math.DenseVector; + +public class HMMTestBase extends MahoutTestCase { + + private HmmModel model; + private final int[] sequence = {1, 0, 2, 2, 0, 0, 1}; + + /** + * We initialize a new HMM model using the following parameters # hidden + * states: 4 ("H0","H1","H2","H3") # output states: 3 ("O0","O1","O2") # + * transition matrix to: H0 H1 H2 H3 from: H0 0.5 0.1 0.1 0.3 H1 0.4 0.4 0.1 + * 0.1 H2 0.1 0.0 0.8 0.1 H3 0.1 0.1 0.1 0.7 # output matrix to: O0 O1 O2 + * from: H0 0.8 0.1 0.1 H1 0.6 0.1 0.3 H2 0.1 0.8 0.1 H3 0.0 0.1 0.9 # initial + * probabilities H0 0.2 + * <p/> + * H1 0.1 H2 0.4 H3 0.3 + * <p/> + * We also intialize an observation sequence: "O1" "O0" "O2" "O2" "O0" "O0" + * "O1" + */ + + @Override + public void setUp() throws Exception { + super.setUp(); + // intialize the hidden/output state names + String[] hiddenNames = {"H0", "H1", "H2", "H3"}; + String[] outputNames = {"O0", "O1", "O2"}; + // initialize the transition matrix + double[][] transitionP = {{0.5, 0.1, 0.1, 0.3}, {0.4, 0.4, 0.1, 0.1}, + {0.1, 0.0, 0.8, 0.1}, {0.1, 0.1, 0.1, 0.7}}; + // initialize the emission matrix + double[][] emissionP = {{0.8, 0.1, 0.1}, {0.6, 0.1, 0.3}, + {0.1, 0.8, 0.1}, {0.0, 0.1, 0.9}}; + // initialize the initial probability vector + double[] initialP = {0.2, 0.1, 0.4, 0.3}; + // now generate the model + model = new HmmModel(new DenseMatrix(transitionP), new DenseMatrix( + emissionP), new DenseVector(initialP)); + model.registerHiddenStateNames(hiddenNames); + model.registerOutputStateNames(outputNames); + // make sure the model is valid :) + HmmUtils.validate(model); + } + + protected HmmModel getModel() { + return model; + } + + protected int[] getSequence() { + return sequence; + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMTrainerTest.java ---------------------------------------------------------------------- diff --git a/mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMTrainerTest.java b/mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMTrainerTest.java new file mode 100644 index 0000000..b8f3186 --- /dev/null +++ b/mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMTrainerTest.java @@ -0,0 +1,163 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.sequencelearning.hmm; + +import org.apache.mahout.math.Matrix; +import org.apache.mahout.math.Vector; +import org.junit.Test; + +public class HMMTrainerTest extends HMMTestBase { + + @Test + public void testViterbiTraining() { + // initialize the expected model parameters (from R) + // expected transition matrix + double[][] transitionE = {{0.3125, 0.0625, 0.3125, 0.3125}, + {0.25, 0.25, 0.25, 0.25}, {0.5, 0.071429, 0.357143, 0.071429}, + {0.5, 0.1, 0.1, 0.3}}; + // initialize the emission matrix + double[][] emissionE = {{0.882353, 0.058824, 0.058824}, + {0.333333, 0.333333, 0.3333333}, {0.076923, 0.846154, 0.076923}, + {0.111111, 0.111111, 0.777778}}; + + // train the given network to the following output sequence + int[] observed = {1, 0, 2, 2, 0, 0, 1, 1, 1, 0, 2, 0, 1, 0, 0}; + + HmmModel trained = HmmTrainer.trainViterbi(getModel(), observed, 0.5, 0.1, 10, false); + + // now check whether the model matches our expectations + Matrix emissionMatrix = trained.getEmissionMatrix(); + Matrix transitionMatrix = trained.getTransitionMatrix(); + + for (int i = 0; i < trained.getNrOfHiddenStates(); ++i) { + for (int j = 0; j < trained.getNrOfHiddenStates(); ++j) { + assertEquals(transitionMatrix.getQuick(i, j), transitionE[i][j], EPSILON); + } + + for (int j = 0; j < trained.getNrOfOutputStates(); ++j) { + assertEquals(emissionMatrix.getQuick(i, j), emissionE[i][j], EPSILON); + } + } + + } + + @Test + public void testScaledViterbiTraining() { + // initialize the expected model parameters (from R) + // expected transition matrix + double[][] transitionE = {{0.3125, 0.0625, 0.3125, 0.3125}, + {0.25, 0.25, 0.25, 0.25}, {0.5, 0.071429, 0.357143, 0.071429}, + {0.5, 0.1, 0.1, 0.3}}; + // initialize the emission matrix + double[][] emissionE = {{0.882353, 0.058824, 0.058824}, + {0.333333, 0.333333, 0.3333333}, {0.076923, 0.846154, 0.076923}, + {0.111111, 0.111111, 0.777778}}; + + // train the given network to the following output sequence + int[] observed = {1, 0, 2, 2, 0, 0, 1, 1, 1, 0, 2, 0, 1, 0, 0}; + + HmmModel trained = HmmTrainer.trainViterbi(getModel(), observed, 0.5, 0.1, 10, + true); + + // now check whether the model matches our expectations + Matrix emissionMatrix = trained.getEmissionMatrix(); + Matrix transitionMatrix = trained.getTransitionMatrix(); + + for (int i = 0; i < trained.getNrOfHiddenStates(); ++i) { + for (int j = 0; j < trained.getNrOfHiddenStates(); ++j) { + assertEquals(transitionMatrix.getQuick(i, j), transitionE[i][j], + EPSILON); + } + + for (int j = 0; j < trained.getNrOfOutputStates(); ++j) { + assertEquals(emissionMatrix.getQuick(i, j), emissionE[i][j], + EPSILON); + } + } + + } + + @Test + public void testBaumWelchTraining() { + // train the given network to the following output sequence + int[] observed = {1, 0, 2, 2, 0, 0, 1, 1, 1, 0, 2, 0, 1, 0, 0}; + + // expected values from Matlab HMM package / R HMM package + double[] initialExpected = {0, 0, 1.0, 0}; + double[][] transitionExpected = {{0.2319, 0.0993, 0.0005, 0.6683}, + {0.0001, 0.3345, 0.6654, 0}, {0.5975, 0, 0.4025, 0}, + {0.0024, 0.6657, 0, 0.3319}}; + double[][] emissionExpected = {{0.9995, 0.0004, 0.0001}, + {0.9943, 0.0036, 0.0021}, {0.0059, 0.9941, 0}, {0, 0, 1}}; + + HmmModel trained = HmmTrainer.trainBaumWelch(getModel(), observed, 0.1, 10, + false); + + Vector initialProbabilities = trained.getInitialProbabilities(); + Matrix emissionMatrix = trained.getEmissionMatrix(); + Matrix transitionMatrix = trained.getTransitionMatrix(); + + for (int i = 0; i < trained.getNrOfHiddenStates(); ++i) { + assertEquals(initialProbabilities.get(i), initialExpected[i], + 0.0001); + for (int j = 0; j < trained.getNrOfHiddenStates(); ++j) { + assertEquals(transitionMatrix.getQuick(i, j), + transitionExpected[i][j], 0.0001); + } + for (int j = 0; j < trained.getNrOfOutputStates(); ++j) { + assertEquals(emissionMatrix.getQuick(i, j), + emissionExpected[i][j], 0.0001); + } + } + } + + @Test + public void testScaledBaumWelchTraining() { + // train the given network to the following output sequence + int[] observed = {1, 0, 2, 2, 0, 0, 1, 1, 1, 0, 2, 0, 1, 0, 0}; + + // expected values from Matlab HMM package / R HMM package + double[] initialExpected = {0, 0, 1.0, 0}; + double[][] transitionExpected = {{0.2319, 0.0993, 0.0005, 0.6683}, + {0.0001, 0.3345, 0.6654, 0}, {0.5975, 0, 0.4025, 0}, + {0.0024, 0.6657, 0, 0.3319}}; + double[][] emissionExpected = {{0.9995, 0.0004, 0.0001}, + {0.9943, 0.0036, 0.0021}, {0.0059, 0.9941, 0}, {0, 0, 1}}; + + HmmModel trained = HmmTrainer + .trainBaumWelch(getModel(), observed, 0.1, 10, true); + + Vector initialProbabilities = trained.getInitialProbabilities(); + Matrix emissionMatrix = trained.getEmissionMatrix(); + Matrix transitionMatrix = trained.getTransitionMatrix(); + + for (int i = 0; i < trained.getNrOfHiddenStates(); ++i) { + assertEquals(initialProbabilities.get(i), initialExpected[i], + 0.0001); + for (int j = 0; j < trained.getNrOfHiddenStates(); ++j) { + assertEquals(transitionMatrix.getQuick(i, j), + transitionExpected[i][j], 0.0001); + } + for (int j = 0; j < trained.getNrOfOutputStates(); ++j) { + assertEquals(emissionMatrix.getQuick(i, j), + emissionExpected[i][j], 0.0001); + } + } + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMUtilsTest.java ---------------------------------------------------------------------- diff --git a/mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMUtilsTest.java b/mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMUtilsTest.java new file mode 100644 index 0000000..6c34718 --- /dev/null +++ b/mr/src/test/java/org/apache/mahout/classifier/sequencelearning/hmm/HMMUtilsTest.java @@ -0,0 +1,161 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.sequencelearning.hmm; + +import java.util.Arrays; +import java.util.List; + +import org.apache.mahout.math.DenseMatrix; +import org.apache.mahout.math.DenseVector; +import org.apache.mahout.math.Matrix; +import org.apache.mahout.math.Vector; +import org.junit.Test; + +public class HMMUtilsTest extends HMMTestBase { + + private Matrix legal22; + private Matrix legal23; + private Matrix legal33; + private Vector legal2; + private Matrix illegal22; + + @Override + public void setUp() throws Exception { + super.setUp(); + legal22 = new DenseMatrix(new double[][]{{0.5, 0.5}, {0.3, 0.7}}); + legal23 = new DenseMatrix(new double[][]{{0.2, 0.2, 0.6}, + {0.3, 0.3, 0.4}}); + legal33 = new DenseMatrix(new double[][]{{0.1, 0.1, 0.8}, + {0.1, 0.2, 0.7}, {0.2, 0.3, 0.5}}); + legal2 = new DenseVector(new double[]{0.4, 0.6}); + illegal22 = new DenseMatrix(new double[][]{{1, 2}, {3, 4}}); + } + + @Test + public void testValidatorLegal() { + HmmUtils.validate(new HmmModel(legal22, legal23, legal2)); + } + + @Test + public void testValidatorDimensionError() { + try { + HmmUtils.validate(new HmmModel(legal33, legal23, legal2)); + } catch (IllegalArgumentException e) { + // success + return; + } + fail(); + } + + @Test + public void testValidatorIllegelMatrixError() { + try { + HmmUtils.validate(new HmmModel(illegal22, legal23, legal2)); + } catch (IllegalArgumentException e) { + // success + return; + } + fail(); + } + + @Test + public void testEncodeStateSequence() { + String[] hiddenSequence = {"H1", "H2", "H0", "H3", "H4"}; + String[] outputSequence = {"O1", "O2", "O4", "O0"}; + // test encoding the hidden Sequence + int[] hiddenSequenceEnc = HmmUtils.encodeStateSequence(getModel(), Arrays + .asList(hiddenSequence), false, -1); + int[] outputSequenceEnc = HmmUtils.encodeStateSequence(getModel(), Arrays + .asList(outputSequence), true, -1); + // expected state sequences + int[] hiddenSequenceExp = {1, 2, 0, 3, -1}; + int[] outputSequenceExp = {1, 2, -1, 0}; + // compare + for (int i = 0; i < hiddenSequenceEnc.length; ++i) { + assertEquals(hiddenSequenceExp[i], hiddenSequenceEnc[i]); + } + for (int i = 0; i < outputSequenceEnc.length; ++i) { + assertEquals(outputSequenceExp[i], outputSequenceEnc[i]); + } + } + + @Test + public void testDecodeStateSequence() { + int[] hiddenSequence = {1, 2, 0, 3, 10}; + int[] outputSequence = {1, 2, 10, 0}; + // test encoding the hidden Sequence + List<String> hiddenSequenceDec = HmmUtils.decodeStateSequence( + getModel(), hiddenSequence, false, "unknown"); + List<String> outputSequenceDec = HmmUtils.decodeStateSequence( + getModel(), outputSequence, true, "unknown"); + // expected state sequences + String[] hiddenSequenceExp = {"H1", "H2", "H0", "H3", "unknown"}; + String[] outputSequenceExp = {"O1", "O2", "unknown", "O0"}; + // compare + for (int i = 0; i < hiddenSequenceExp.length; ++i) { + assertEquals(hiddenSequenceExp[i], hiddenSequenceDec.get(i)); + } + for (int i = 0; i < outputSequenceExp.length; ++i) { + assertEquals(outputSequenceExp[i], outputSequenceDec.get(i)); + } + } + + @Test + public void testNormalizeModel() { + DenseVector ip = new DenseVector(new double[]{10, 20}); + DenseMatrix tr = new DenseMatrix(new double[][]{{10, 10}, {20, 25}}); + DenseMatrix em = new DenseMatrix(new double[][]{{5, 7}, {10, 15}}); + HmmModel model = new HmmModel(tr, em, ip); + HmmUtils.normalizeModel(model); + // the model should be valid now + HmmUtils.validate(model); + } + + @Test + public void testTruncateModel() { + DenseVector ip = new DenseVector(new double[]{0.0001, 0.0001, 0.9998}); + DenseMatrix tr = new DenseMatrix(new double[][]{ + {0.9998, 0.0001, 0.0001}, {0.0001, 0.9998, 0.0001}, + {0.0001, 0.0001, 0.9998}}); + DenseMatrix em = new DenseMatrix(new double[][]{ + {0.9998, 0.0001, 0.0001}, {0.0001, 0.9998, 0.0001}, + {0.0001, 0.0001, 0.9998}}); + HmmModel model = new HmmModel(tr, em, ip); + // now truncate the model + HmmModel sparseModel = HmmUtils.truncateModel(model, 0.01); + // first make sure this is a valid model + HmmUtils.validate(sparseModel); + // now check whether the values are as expected + Vector sparse_ip = sparseModel.getInitialProbabilities(); + Matrix sparse_tr = sparseModel.getTransitionMatrix(); + Matrix sparse_em = sparseModel.getEmissionMatrix(); + for (int i = 0; i < sparseModel.getNrOfHiddenStates(); ++i) { + assertEquals(i == 2 ? 1.0 : 0.0, sparse_ip.getQuick(i), EPSILON); + for (int j = 0; j < sparseModel.getNrOfHiddenStates(); ++j) { + if (i == j) { + assertEquals(1.0, sparse_tr.getQuick(i, j), EPSILON); + assertEquals(1.0, sparse_em.getQuick(i, j), EPSILON); + } else { + assertEquals(0.0, sparse_tr.getQuick(i, j), EPSILON); + assertEquals(0.0, sparse_em.getQuick(i, j), EPSILON); + } + } + } + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegressionTest.java ---------------------------------------------------------------------- diff --git a/mr/src/test/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegressionTest.java b/mr/src/test/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegressionTest.java new file mode 100644 index 0000000..7ea8cb2 --- /dev/null +++ b/mr/src/test/java/org/apache/mahout/classifier/sgd/AdaptiveLogisticRegressionTest.java @@ -0,0 +1,186 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.sgd; + +import org.apache.mahout.common.MahoutTestCase; +import org.apache.mahout.common.RandomUtils; +import org.apache.mahout.math.DenseVector; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.jet.random.Exponential; +import org.junit.Test; + +import com.carrotsearch.randomizedtesting.annotations.ThreadLeakLingering; + +import java.util.Random; + +public final class AdaptiveLogisticRegressionTest extends MahoutTestCase { + + @ThreadLeakLingering(linger=1000) + @Test + public void testTrain() { + + Random gen = RandomUtils.getRandom(); + Exponential exp = new Exponential(0.5, gen); + Vector beta = new DenseVector(200); + for (Vector.Element element : beta.all()) { + int sign = 1; + if (gen.nextDouble() < 0.5) { + sign = -1; + } + element.set(sign * exp.nextDouble()); + } + + AdaptiveLogisticRegression.Wrapper cl = new AdaptiveLogisticRegression.Wrapper(2, 200, new L1()); + cl.update(new double[]{1.0e-5, 1}); + + for (int i = 0; i < 10000; i++) { + AdaptiveLogisticRegression.TrainingExample r = getExample(i, gen, beta); + cl.train(r); + if (i % 1000 == 0) { + System.out.printf("%10d %10.3f\n", i, cl.getLearner().auc()); + } + } + assertEquals(1, cl.getLearner().auc(), 0.1); + + AdaptiveLogisticRegression adaptiveLogisticRegression = new AdaptiveLogisticRegression(2, 200, new L1()); + adaptiveLogisticRegression.setInterval(1000); + + for (int i = 0; i < 20000; i++) { + AdaptiveLogisticRegression.TrainingExample r = getExample(i, gen, beta); + adaptiveLogisticRegression.train(r.getKey(), r.getActual(), r.getInstance()); + if (i % 1000 == 0 && adaptiveLogisticRegression.getBest() != null) { + System.out.printf("%10d %10.4f %10.8f %.3f\n", + i, adaptiveLogisticRegression.auc(), + Math.log10(adaptiveLogisticRegression.getBest().getMappedParams()[0]), adaptiveLogisticRegression.getBest().getMappedParams()[1]); + } + } + assertEquals(1, adaptiveLogisticRegression.auc(), 0.1); + adaptiveLogisticRegression.close(); + } + + private static AdaptiveLogisticRegression.TrainingExample getExample(int i, Random gen, Vector beta) { + Vector data = new DenseVector(200); + + for (Vector.Element element : data.all()) { + element.set(gen.nextDouble() < 0.3 ? 1 : 0); + } + + double p = 1 / (1 + Math.exp(1.5 - data.dot(beta))); + int target = 0; + if (gen.nextDouble() < p) { + target = 1; + } + return new AdaptiveLogisticRegression.TrainingExample(i, null, target, data); + } + + @Test + public void copyLearnsAsExpected() { + Random gen = RandomUtils.getRandom(); + Exponential exp = new Exponential(0.5, gen); + Vector beta = new DenseVector(200); + for (Vector.Element element : beta.all()) { + int sign = 1; + if (gen.nextDouble() < 0.5) { + sign = -1; + } + element.set(sign * exp.nextDouble()); + } + + // train one copy of a wrapped learner + AdaptiveLogisticRegression.Wrapper w = new AdaptiveLogisticRegression.Wrapper(2, 200, new L1()); + for (int i = 0; i < 3000; i++) { + AdaptiveLogisticRegression.TrainingExample r = getExample(i, gen, beta); + w.train(r); + if (i % 1000 == 0) { + System.out.printf("%10d %.3f\n", i, w.getLearner().auc()); + } + } + System.out.printf("%10d %.3f\n", 3000, w.getLearner().auc()); + double auc1 = w.getLearner().auc(); + + // then switch to a copy of that learner ... progress should continue + AdaptiveLogisticRegression.Wrapper w2 = w.copy(); + + for (int i = 0; i < 5000; i++) { + if (i % 1000 == 0) { + if (i == 0) { + assertEquals("Should have started with no data", 0.5, w2.getLearner().auc(), 0.0001); + } + if (i == 1000) { + double auc2 = w2.getLearner().auc(); + assertTrue("Should have had head-start", Math.abs(auc2 - 0.5) > 0.1); + assertTrue("AUC should improve quickly on copy", auc1 < auc2); + } + System.out.printf("%10d %.3f\n", i, w2.getLearner().auc()); + } + AdaptiveLogisticRegression.TrainingExample r = getExample(i, gen, beta); + w2.train(r); + } + assertEquals("Original should not change after copy is updated", auc1, w.getLearner().auc(), 1.0e-5); + + // this improvement is really quite lenient + assertTrue("AUC should improve significantly on copy", auc1 < w2.getLearner().auc() - 0.05); + + // make sure that the copy didn't lose anything + assertEquals(auc1, w.getLearner().auc(), 0); + } + + @Test + public void stepSize() { + assertEquals(500, AdaptiveLogisticRegression.stepSize(15000, 2)); + assertEquals(2000, AdaptiveLogisticRegression.stepSize(15000, 2.6)); + assertEquals(5000, AdaptiveLogisticRegression.stepSize(24000, 2.6)); + assertEquals(10000, AdaptiveLogisticRegression.stepSize(15000, 3)); + } + + @Test + @ThreadLeakLingering(linger = 1000) + public void constantStep() { + AdaptiveLogisticRegression lr = new AdaptiveLogisticRegression(2, 1000, new L1()); + lr.setInterval(5000); + assertEquals(20000, lr.nextStep(15000)); + assertEquals(20000, lr.nextStep(15001)); + assertEquals(20000, lr.nextStep(16500)); + assertEquals(20000, lr.nextStep(19999)); + lr.close(); + } + + + @Test + @ThreadLeakLingering(linger = 1000) + public void growingStep() { + AdaptiveLogisticRegression lr = new AdaptiveLogisticRegression(2, 1000, new L1()); + lr.setInterval(2000, 10000); + + // start with minimum step size + for (int i = 2000; i < 20000; i+=2000) { + assertEquals(i + 2000, lr.nextStep(i)); + } + + // then level up a bit + for (int i = 20000; i < 50000; i += 5000) { + assertEquals(i + 5000, lr.nextStep(i)); + } + + // and more, but we top out with this step size + for (int i = 50000; i < 500000; i += 10000) { + assertEquals(i + 10000, lr.nextStep(i)); + } + lr.close(); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/classifier/sgd/CsvRecordFactoryTest.java ---------------------------------------------------------------------- diff --git a/mr/src/test/java/org/apache/mahout/classifier/sgd/CsvRecordFactoryTest.java b/mr/src/test/java/org/apache/mahout/classifier/sgd/CsvRecordFactoryTest.java new file mode 100644 index 0000000..6ee0ddf --- /dev/null +++ b/mr/src/test/java/org/apache/mahout/classifier/sgd/CsvRecordFactoryTest.java @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.sgd; + +import com.google.common.collect.ImmutableMap; +import org.apache.mahout.common.MahoutTestCase; +import org.apache.mahout.math.DenseVector; +import org.apache.mahout.math.Vector; +import org.apache.mahout.vectorizer.encoders.Dictionary; +import org.junit.Test; + +public final class CsvRecordFactoryTest extends MahoutTestCase { + + @Test + public void testAddToVector() { + RecordFactory csv = new CsvRecordFactory("y", ImmutableMap.of("x1", "n", "x2", "w", "x3", "t")); + csv.firstLine("z,x1,y,x2,x3,q"); + csv.maxTargetValue(2); + + Vector v = new DenseVector(2000); + int t = csv.processLine("ignore,3.1,yes,tiger, \"this is text\",ignore", v); + assertEquals(0, t); + // should have 9 values set + assertEquals(9.0, v.norm(0), 0); + // all should be = 1 except for the 3.1 + assertEquals(3.1, v.maxValue(), 0); + v.set(v.maxValueIndex(), 0); + assertEquals(8.0, v.norm(0), 0); + assertEquals(8.0, v.norm(1), 0); + assertEquals(1.0, v.maxValue(), 0); + + v.assign(0); + t = csv.processLine("ignore,5.3,no,line, \"and more text and more\",ignore", v); + assertEquals(1, t); + + // should have 9 values set + assertEquals(9.0, v.norm(0), 0); + // all should be = 1 except for the 3.1 + assertEquals(5.3, v.maxValue(), 0); + v.set(v.maxValueIndex(), 0); + assertEquals(8.0, v.norm(0), 0); + assertEquals(10.339850002884626, v.norm(1), 1.0e-6); + assertEquals(1.5849625007211563, v.maxValue(), 1.0e-6); + + v.assign(0); + t = csv.processLine("ignore,5.3,invalid,line, \"and more text and more\",ignore", v); + assertEquals(1, t); + + // should have 9 values set + assertEquals(9.0, v.norm(0), 0); + // all should be = 1 except for the 3.1 + assertEquals(5.3, v.maxValue(), 0); + v.set(v.maxValueIndex(), 0); + assertEquals(8.0, v.norm(0), 0); + assertEquals(10.339850002884626, v.norm(1), 1.0e-6); + assertEquals(1.5849625007211563, v.maxValue(), 1.0e-6); + } + + @Test + public void testDictionaryOrder() { + Dictionary dict = new Dictionary(); + + dict.intern("a"); + dict.intern("d"); + dict.intern("c"); + dict.intern("b"); + dict.intern("qrz"); + + assertEquals("[a, d, c, b, qrz]", dict.values().toString()); + + Dictionary dict2 = Dictionary.fromList(dict.values()); + assertEquals("[a, d, c, b, qrz]", dict2.values().toString()); + + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/classifier/sgd/GradientMachineTest.java ---------------------------------------------------------------------- diff --git a/mr/src/test/java/org/apache/mahout/classifier/sgd/GradientMachineTest.java b/mr/src/test/java/org/apache/mahout/classifier/sgd/GradientMachineTest.java new file mode 100644 index 0000000..06a876e --- /dev/null +++ b/mr/src/test/java/org/apache/mahout/classifier/sgd/GradientMachineTest.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.sgd; + +import org.apache.mahout.common.RandomUtils; +import org.apache.mahout.math.Vector; +import org.junit.Test; + +import java.io.IOException; +import java.util.Random; + +public final class GradientMachineTest extends OnlineBaseTest { + + @Test + public void testGradientmachine() throws IOException { + Vector target = readStandardData(); + GradientMachine grad = new GradientMachine(8,4,2).learningRate(0.1).regularization(0.01); + Random gen = RandomUtils.getRandom(); + grad.initWeights(gen); + train(getInput(), target, grad); + // TODO not sure why the RNG change made this fail. Value is 0.5-1.0 no matter what seed is chosen? + test(getInput(), target, grad, 1.0, 1); + //test(getInput(), target, grad, 0.05, 1); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/classifier/sgd/ModelSerializerTest.java ---------------------------------------------------------------------- diff --git a/mr/src/test/java/org/apache/mahout/classifier/sgd/ModelSerializerTest.java b/mr/src/test/java/org/apache/mahout/classifier/sgd/ModelSerializerTest.java new file mode 100644 index 0000000..2373b9d --- /dev/null +++ b/mr/src/test/java/org/apache/mahout/classifier/sgd/ModelSerializerTest.java @@ -0,0 +1,162 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.sgd; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.IOException; +import java.util.Random; + +import com.carrotsearch.randomizedtesting.annotations.ThreadLeakLingering; +import com.google.common.io.Closeables; +import org.apache.hadoop.io.Writable; +import org.apache.mahout.classifier.OnlineLearner; +import org.apache.mahout.common.MahoutTestCase; +import org.apache.mahout.common.RandomUtils; +import org.apache.mahout.math.DenseVector; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.function.DoubleFunction; +import org.apache.mahout.math.function.Functions; +import org.apache.mahout.math.stats.GlobalOnlineAuc; +import org.apache.mahout.math.stats.OnlineAuc; +import org.junit.Test; + +public final class ModelSerializerTest extends MahoutTestCase { + + private static <T extends Writable> T roundTrip(T m, Class<T> clazz) throws IOException { + ByteArrayOutputStream buf = new ByteArrayOutputStream(1000); + DataOutputStream dos = new DataOutputStream(buf); + try { + PolymorphicWritable.write(dos, m); + } finally { + Closeables.close(dos, false); + } + return PolymorphicWritable.read(new DataInputStream(new ByteArrayInputStream(buf.toByteArray())), clazz); + } + + @Test + public void onlineAucRoundtrip() throws IOException { + RandomUtils.useTestSeed(); + OnlineAuc auc1 = new GlobalOnlineAuc(); + Random gen = RandomUtils.getRandom(); + for (int i = 0; i < 10000; i++) { + auc1.addSample(0, gen.nextGaussian()); + auc1.addSample(1, gen.nextGaussian() + 1); + } + assertEquals(0.76, auc1.auc(), 0.01); + + OnlineAuc auc3 = roundTrip(auc1, OnlineAuc.class); + + assertEquals(auc1.auc(), auc3.auc(), 0); + + for (int i = 0; i < 1000; i++) { + auc1.addSample(0, gen.nextGaussian()); + auc1.addSample(1, gen.nextGaussian() + 1); + + auc3.addSample(0, gen.nextGaussian()); + auc3.addSample(1, gen.nextGaussian() + 1); + } + + assertEquals(auc1.auc(), auc3.auc(), 0.01); + } + + @Test + public void onlineLogisticRegressionRoundTrip() throws IOException { + OnlineLogisticRegression olr = new OnlineLogisticRegression(2, 5, new L1()); + train(olr, 100); + OnlineLogisticRegression olr3 = roundTrip(olr, OnlineLogisticRegression.class); + assertEquals(0, olr.getBeta().minus(olr3.getBeta()).aggregate(Functions.MAX, Functions.IDENTITY), 1.0e-6); + + train(olr, 100); + train(olr3, 100); + + assertEquals(0, olr.getBeta().minus(olr3.getBeta()).aggregate(Functions.MAX, Functions.IDENTITY), 1.0e-6); + olr.close(); + olr3.close(); + } + + @Test + public void crossFoldLearnerRoundTrip() throws IOException { + CrossFoldLearner learner = new CrossFoldLearner(5, 2, 5, new L1()); + train(learner, 100); + CrossFoldLearner olr3 = roundTrip(learner, CrossFoldLearner.class); + double auc1 = learner.auc(); + assertTrue(auc1 > 0.85); + assertEquals(auc1, learner.auc(), 1.0e-6); + assertEquals(auc1, olr3.auc(), 1.0e-6); + + train(learner, 100); + train(learner, 100); + train(olr3, 100); + + assertEquals(learner.auc(), learner.auc(), 0.02); + assertEquals(learner.auc(), olr3.auc(), 0.02); + double auc2 = learner.auc(); + assertTrue(auc2 > auc1); + learner.close(); + olr3.close(); + } + + @ThreadLeakLingering(linger = 1000) + @Test + public void adaptiveLogisticRegressionRoundTrip() throws IOException { + AdaptiveLogisticRegression learner = new AdaptiveLogisticRegression(2, 5, new L1()); + learner.setInterval(200); + train(learner, 400); + AdaptiveLogisticRegression olr3 = roundTrip(learner, AdaptiveLogisticRegression.class); + double auc1 = learner.auc(); + assertTrue(auc1 > 0.85); + assertEquals(auc1, learner.auc(), 1.0e-6); + assertEquals(auc1, olr3.auc(), 1.0e-6); + + train(learner, 1000); + train(learner, 1000); + train(olr3, 1000); + + assertEquals(learner.auc(), learner.auc(), 0.005); + assertEquals(learner.auc(), olr3.auc(), 0.005); + double auc2 = learner.auc(); + assertTrue(String.format("%.3f > %.3f", auc2, auc1), auc2 > auc1); + learner.close(); + olr3.close(); + } + + private static void train(OnlineLearner olr, int n) { + Vector beta = new DenseVector(new double[]{1, -1, 0, 0.5, -0.5}); + Random gen = RandomUtils.getRandom(); + for (int i = 0; i < n; i++) { + Vector x = randomVector(gen, 5); + + int target = gen.nextDouble() < beta.dot(x) ? 1 : 0; + olr.train(target, x); + } + } + + private static Vector randomVector(final Random gen, int n) { + Vector x = new DenseVector(n); + x.assign(new DoubleFunction() { + @Override + public double apply(double v) { + return gen.nextGaussian(); + } + }); + return x; + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/classifier/sgd/OnlineBaseTest.java ---------------------------------------------------------------------- diff --git a/mr/src/test/java/org/apache/mahout/classifier/sgd/OnlineBaseTest.java b/mr/src/test/java/org/apache/mahout/classifier/sgd/OnlineBaseTest.java new file mode 100644 index 0000000..e0a252c --- /dev/null +++ b/mr/src/test/java/org/apache/mahout/classifier/sgd/OnlineBaseTest.java @@ -0,0 +1,160 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.sgd; + +import com.google.common.base.CharMatcher; +import com.google.common.base.Charsets; +import com.google.common.base.Splitter; +import com.google.common.collect.Lists; +import com.google.common.collect.Maps; +import com.google.common.io.CharStreams; +import com.google.common.io.Resources; +import org.apache.mahout.classifier.AbstractVectorClassifier; +import org.apache.mahout.classifier.OnlineLearner; +import org.apache.mahout.common.MahoutTestCase; +import org.apache.mahout.common.RandomUtils; +import org.apache.mahout.math.DenseMatrix; +import org.apache.mahout.math.DenseVector; +import org.apache.mahout.math.Matrix; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.function.Functions; + +import java.io.IOException; +import java.io.InputStreamReader; +import java.util.List; +import java.util.Map; +import java.util.Random; + +public abstract class OnlineBaseTest extends MahoutTestCase { + + private Matrix input; + + Matrix getInput() { + return input; + } + + Vector readStandardData() throws IOException { + // 60 test samples. First column is constant. Second and third are normally distributed from + // either N([2,2], 1) (rows 0...29) or N([-2,-2], 1) (rows 30...59). The first 30 rows have a + // target variable of 0, the last 30 a target of 1. The remaining columns are are random noise. + input = readCsv("sgd.csv"); + + // regenerate the target variable + Vector target = new DenseVector(60); + target.assign(0); + target.viewPart(30, 30).assign(1); + return target; + } + + static void train(Matrix input, Vector target, OnlineLearner lr) { + RandomUtils.useTestSeed(); + Random gen = RandomUtils.getRandom(); + + // train on samples in random order (but only one pass) + for (int row : permute(gen, 60)) { + lr.train((int) target.get(row), input.viewRow(row)); + } + lr.close(); + } + + static void test(Matrix input, Vector target, AbstractVectorClassifier lr, + double expected_mean_error, double expected_absolute_error) { + // now test the accuracy + Matrix tmp = lr.classify(input); + // mean(abs(tmp - target)) + double meanAbsoluteError = tmp.viewColumn(0).minus(target).aggregate(Functions.PLUS, Functions.ABS) / 60; + + // max(abs(tmp - target) + double maxAbsoluteError = tmp.viewColumn(0).minus(target).aggregate(Functions.MAX, Functions.ABS); + + System.out.printf("mAE = %.4f, maxAE = %.4f\n", meanAbsoluteError, maxAbsoluteError); + assertEquals(0, meanAbsoluteError , expected_mean_error); + assertEquals(0, maxAbsoluteError, expected_absolute_error); + + // convenience methods should give the same results + Vector v = lr.classifyScalar(input); + assertEquals(0, v.minus(tmp.viewColumn(0)).norm(1), 1.0e-5); + v = lr.classifyFull(input).viewColumn(1); + assertEquals(0, v.minus(tmp.viewColumn(0)).norm(1), 1.0e-4); + } + + /** + * Permute the integers from 0 ... max-1 + * + * @param gen The random number generator to use. + * @param max The number of integers to permute + * @return An array of jumbled integer values + */ + static int[] permute(Random gen, int max) { + int[] permutation = new int[max]; + permutation[0] = 0; + for (int i = 1; i < max; i++) { + int n = gen.nextInt(i + 1); + if (n == i) { + permutation[i] = i; + } else { + permutation[i] = permutation[n]; + permutation[n] = i; + } + } + return permutation; + } + + + /** + * Reads a file containing CSV data. This isn't implemented quite the way you might like for a + * real program, but does the job for reading test data. Most notably, it will only read numbers, + * not quoted strings. + * + * @param resourceName Where to get the data. + * @return A matrix of the results. + * @throws IOException If there is an error reading the data + */ + static Matrix readCsv(String resourceName) throws IOException { + Splitter onCommas = Splitter.on(',').trimResults(CharMatcher.anyOf(" \"")); + + Readable isr = new InputStreamReader(Resources.getResource(resourceName).openStream(), Charsets.UTF_8); + List<String> data = CharStreams.readLines(isr); + String first = data.get(0); + data = data.subList(1, data.size()); + + List<String> values = Lists.newArrayList(onCommas.split(first)); + Matrix r = new DenseMatrix(data.size(), values.size()); + + int column = 0; + Map<String, Integer> labels = Maps.newHashMap(); + for (String value : values) { + labels.put(value, column); + column++; + } + r.setColumnLabelBindings(labels); + + int row = 0; + for (String line : data) { + column = 0; + values = Lists.newArrayList(onCommas.split(line)); + for (String value : values) { + r.set(row, column, Double.parseDouble(value)); + column++; + } + row++; + } + + return r; + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegressionTest.java ---------------------------------------------------------------------- diff --git a/mr/src/test/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegressionTest.java b/mr/src/test/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegressionTest.java new file mode 100644 index 0000000..44b7525 --- /dev/null +++ b/mr/src/test/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegressionTest.java @@ -0,0 +1,330 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.sgd; + +import com.google.common.base.Charsets; +import com.google.common.base.Splitter; +import com.google.common.collect.Iterables; +import com.google.common.collect.Lists; +import com.google.common.io.Resources; +import org.apache.mahout.common.RandomUtils; +import org.apache.mahout.math.DenseVector; +import org.apache.mahout.math.Matrix; +import org.apache.mahout.math.Vector; +import org.apache.mahout.vectorizer.encoders.Dictionary; +import org.junit.Assert; +import org.junit.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.IOException; +import java.lang.reflect.Field; +import java.util.Collections; +import java.util.List; +import java.util.Random; + + +public final class OnlineLogisticRegressionTest extends OnlineBaseTest { + + private static final Logger logger = LoggerFactory.getLogger(OnlineLogisticRegressionTest.class); + + /** + * The CrossFoldLearner is probably the best learner to use for new applications. + * + * @throws IOException If test resources aren't readable. + */ + @Test + public void crossValidation() throws IOException { + Vector target = readStandardData(); + + CrossFoldLearner lr = new CrossFoldLearner(5, 2, 8, new L1()) + .lambda(1 * 1.0e-3) + .learningRate(50); + + + train(getInput(), target, lr); + + System.out.printf("%.2f %.5f\n", lr.auc(), lr.logLikelihood()); + test(getInput(), target, lr, 0.05, 0.3); + + } + + @Test + public void crossValidatedAuc() throws IOException { + RandomUtils.useTestSeed(); + Random gen = RandomUtils.getRandom(); + + Matrix data = readCsv("cancer.csv"); + CrossFoldLearner lr = new CrossFoldLearner(5, 2, 10, new L1()) + .stepOffset(10) + .decayExponent(0.7) + .lambda(1 * 1.0e-3) + .learningRate(5); + int k = 0; + int[] ordering = permute(gen, data.numRows()); + for (int epoch = 0; epoch < 100; epoch++) { + for (int row : ordering) { + lr.train(row, (int) data.get(row, 9), data.viewRow(row)); + System.out.printf("%d,%d,%.3f\n", epoch, k++, lr.auc()); + } + assertEquals(1, lr.auc(), 0.2); + } + assertEquals(1, lr.auc(), 0.1); + } + + /** + * Verifies that a classifier with known coefficients does the right thing. + */ + @Test + public void testClassify() { + OnlineLogisticRegression lr = new OnlineLogisticRegression(3, 2, new L2(1)); + // set up some internal coefficients as if we had learned them + lr.setBeta(0, 0, -1); + lr.setBeta(1, 0, -2); + + // zero vector gives no information. All classes are equal. + Vector v = lr.classify(new DenseVector(new double[]{0, 0})); + assertEquals(1 / 3.0, v.get(0), 1.0e-8); + assertEquals(1 / 3.0, v.get(1), 1.0e-8); + + v = lr.classifyFull(new DenseVector(new double[]{0, 0})); + assertEquals(1.0, v.zSum(), 1.0e-8); + assertEquals(1 / 3.0, v.get(0), 1.0e-8); + assertEquals(1 / 3.0, v.get(1), 1.0e-8); + assertEquals(1 / 3.0, v.get(2), 1.0e-8); + + // weights for second vector component are still zero so all classifications are equally likely + v = lr.classify(new DenseVector(new double[]{0, 1})); + assertEquals(1 / 3.0, v.get(0), 1.0e-3); + assertEquals(1 / 3.0, v.get(1), 1.0e-3); + + v = lr.classifyFull(new DenseVector(new double[]{0, 1})); + assertEquals(1.0, v.zSum(), 1.0e-8); + assertEquals(1 / 3.0, v.get(0), 1.0e-3); + assertEquals(1 / 3.0, v.get(1), 1.0e-3); + assertEquals(1 / 3.0, v.get(2), 1.0e-3); + + // but the weights on the first component are non-zero + v = lr.classify(new DenseVector(new double[]{1, 0})); + assertEquals(Math.exp(-1) / (1 + Math.exp(-1) + Math.exp(-2)), v.get(0), 1.0e-8); + assertEquals(Math.exp(-2) / (1 + Math.exp(-1) + Math.exp(-2)), v.get(1), 1.0e-8); + + v = lr.classifyFull(new DenseVector(new double[]{1, 0})); + assertEquals(1.0, v.zSum(), 1.0e-8); + assertEquals(1 / (1 + Math.exp(-1) + Math.exp(-2)), v.get(0), 1.0e-8); + assertEquals(Math.exp(-1) / (1 + Math.exp(-1) + Math.exp(-2)), v.get(1), 1.0e-8); + assertEquals(Math.exp(-2) / (1 + Math.exp(-1) + Math.exp(-2)), v.get(2), 1.0e-8); + + lr.setBeta(0, 1, 1); + + v = lr.classifyFull(new DenseVector(new double[]{1, 1})); + assertEquals(1.0, v.zSum(), 1.0e-8); + assertEquals(Math.exp(0) / (1 + Math.exp(0) + Math.exp(-2)), v.get(1), 1.0e-3); + assertEquals(Math.exp(-2) / (1 + Math.exp(0) + Math.exp(-2)), v.get(2), 1.0e-3); + assertEquals(1 / (1 + Math.exp(0) + Math.exp(-2)), v.get(0), 1.0e-3); + + lr.setBeta(1, 1, 3); + + v = lr.classifyFull(new DenseVector(new double[]{1, 1})); + assertEquals(1.0, v.zSum(), 1.0e-8); + assertEquals(Math.exp(0) / (1 + Math.exp(0) + Math.exp(1)), v.get(1), 1.0e-8); + assertEquals(Math.exp(1) / (1 + Math.exp(0) + Math.exp(1)), v.get(2), 1.0e-8); + assertEquals(1 / (1 + Math.exp(0) + Math.exp(1)), v.get(0), 1.0e-8); + } + + @Test + public void iris() throws IOException { + // this test trains a 3-way classifier on the famous Iris dataset. + // a similar exercise can be accomplished in R using this code: + // library(nnet) + // correct = rep(0,100) + // for (j in 1:100) { + // i = order(runif(150)) + // train = iris[i[1:100],] + // test = iris[i[101:150],] + // m = multinom(Species ~ Sepal.Length + Sepal.Width + Petal.Length + Petal.Width, train) + // correct[j] = mean(predict(m, newdata=test) == test$Species) + // } + // hist(correct) + // + // Note that depending on the training/test split, performance can be better or worse. + // There is about a 5% chance of getting accuracy < 90% and about 20% chance of getting accuracy + // of 100% + // + // This test uses a deterministic split that is neither outstandingly good nor bad + + + RandomUtils.useTestSeed(); + Splitter onComma = Splitter.on(","); + + // read the data + List<String> raw = Resources.readLines(Resources.getResource("iris.csv"), Charsets.UTF_8); + + // holds features + List<Vector> data = Lists.newArrayList(); + + // holds target variable + List<Integer> target = Lists.newArrayList(); + + // for decoding target values + Dictionary dict = new Dictionary(); + + // for permuting data later + List<Integer> order = Lists.newArrayList(); + + for (String line : raw.subList(1, raw.size())) { + // order gets a list of indexes + order.add(order.size()); + + // parse the predictor variables + Vector v = new DenseVector(5); + v.set(0, 1); + int i = 1; + Iterable<String> values = onComma.split(line); + for (String value : Iterables.limit(values, 4)) { + v.set(i++, Double.parseDouble(value)); + } + data.add(v); + + // and the target + target.add(dict.intern(Iterables.get(values, 4))); + } + + // randomize the order ... original data has each species all together + // note that this randomization is deterministic + Random random = RandomUtils.getRandom(); + Collections.shuffle(order, random); + + // select training and test data + List<Integer> train = order.subList(0, 100); + List<Integer> test = order.subList(100, 150); + logger.warn("Training set = {}", train); + logger.warn("Test set = {}", test); + + // now train many times and collect information on accuracy each time + int[] correct = new int[test.size() + 1]; + for (int run = 0; run < 200; run++) { + OnlineLogisticRegression lr = new OnlineLogisticRegression(3, 5, new L2(1)); + // 30 training passes should converge to > 95% accuracy nearly always but never to 100% + for (int pass = 0; pass < 30; pass++) { + Collections.shuffle(train, random); + for (int k : train) { + lr.train(target.get(k), data.get(k)); + } + } + + // check the accuracy on held out data + int x = 0; + int[] count = new int[3]; + for (Integer k : test) { + int r = lr.classifyFull(data.get(k)).maxValueIndex(); + count[r]++; + x += r == target.get(k) ? 1 : 0; + } + correct[x]++; + } + + // verify we never saw worse than 95% correct, + for (int i = 0; i < Math.floor(0.95 * test.size()); i++) { + assertEquals(String.format("%d trials had unacceptable accuracy of only %.0f%%: ", correct[i], 100.0 * i / test.size()), 0, correct[i]); + } + // nor perfect + assertEquals(String.format("%d trials had unrealistic accuracy of 100%%", correct[test.size() - 1]), 0, correct[test.size()]); + } + + @Test + public void testTrain() throws Exception { + Vector target = readStandardData(); + + + // lambda here needs to be relatively small to avoid swamping the actual signal, but can be + // larger than usual because the data are dense. The learning rate doesn't matter too much + // for this example, but should generally be < 1 + // --passes 1 --rate 50 --lambda 0.001 --input sgd-y.csv --features 21 --output model --noBias + // --target y --categories 2 --predictors V2 V3 V4 V5 V6 V7 --types n + OnlineLogisticRegression lr = new OnlineLogisticRegression(2, 8, new L1()) + .lambda(1 * 1.0e-3) + .learningRate(50); + + train(getInput(), target, lr); + test(getInput(), target, lr, 0.05, 0.3); + } + + /** + * Test for Serialization/DeSerialization + * + */ + @Test + public void testSerializationAndDeSerialization() throws Exception { + OnlineLogisticRegression lr = new OnlineLogisticRegression(2, 8, new L1()) + .lambda(1 * 1.0e-3) + .stepOffset(11) + .alpha(0.01) + .learningRate(50) + .decayExponent(-0.02); + + lr.close(); + + byte[] output; + + try (ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream(); + DataOutputStream dataOutputStream = new DataOutputStream(byteArrayOutputStream)) { + PolymorphicWritable.write(dataOutputStream, lr); + output = byteArrayOutputStream.toByteArray(); + } + + OnlineLogisticRegression read; + + try (ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(output); + DataInputStream dataInputStream = new DataInputStream(byteArrayInputStream)) { + read = PolymorphicWritable.read(dataInputStream, OnlineLogisticRegression.class); + } + + //lambda + Assert.assertEquals((1.0e-3), read.getLambda(), 1.0e-7); + + // Reflection to get private variables + //stepOffset + Field stepOffset = lr.getClass().getDeclaredField("stepOffset"); + stepOffset.setAccessible(true); + int stepOffsetVal = (Integer) stepOffset.get(lr); + Assert.assertEquals(11, stepOffsetVal); + + //decayFactor (alpha) + Field decayFactor = lr.getClass().getDeclaredField("decayFactor"); + decayFactor.setAccessible(true); + double decayFactorVal = (Double) decayFactor.get(lr); + Assert.assertEquals(0.01, decayFactorVal, 1.0e-7); + + //learning rate (mu0) + Field mu0 = lr.getClass().getDeclaredField("mu0"); + mu0.setAccessible(true); + double mu0Val = (Double) mu0.get(lr); + Assert.assertEquals(50, mu0Val, 1.0e-7); + + //forgettingExponent (decayExponent) + Field forgettingExponent = lr.getClass().getDeclaredField("forgettingExponent"); + forgettingExponent.setAccessible(true); + double forgettingExponentVal = (Double) forgettingExponent.get(lr); + Assert.assertEquals(-0.02, forgettingExponentVal, 1.0e-7); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/classifier/sgd/PassiveAggressiveTest.java ---------------------------------------------------------------------- diff --git a/mr/src/test/java/org/apache/mahout/classifier/sgd/PassiveAggressiveTest.java b/mr/src/test/java/org/apache/mahout/classifier/sgd/PassiveAggressiveTest.java new file mode 100644 index 0000000..df97d38 --- /dev/null +++ b/mr/src/test/java/org/apache/mahout/classifier/sgd/PassiveAggressiveTest.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.sgd; + +import org.apache.mahout.math.Vector; +import org.junit.Test; + +import java.io.IOException; + +public final class PassiveAggressiveTest extends OnlineBaseTest { + + @Test + public void testPassiveAggressive() throws IOException { + Vector target = readStandardData(); + PassiveAggressive pa = new PassiveAggressive(2,8).learningRate(0.1); + train(getInput(), target, pa); + test(getInput(), target, pa, 0.11, 0.31); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/clustering/ClusteringTestUtils.java ---------------------------------------------------------------------- diff --git a/mr/src/test/java/org/apache/mahout/clustering/ClusteringTestUtils.java b/mr/src/test/java/org/apache/mahout/clustering/ClusteringTestUtils.java new file mode 100644 index 0000000..62e10c6 --- /dev/null +++ b/mr/src/test/java/org/apache/mahout/clustering/ClusteringTestUtils.java @@ -0,0 +1,152 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.clustering; + +import java.io.IOException; +import java.util.Random; + +import com.google.common.base.Preconditions; +import com.google.common.io.Closeables; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.LongWritable; +import org.apache.hadoop.io.SequenceFile; +import org.apache.mahout.math.DenseMatrix; +import org.apache.mahout.math.DenseVector; +import org.apache.mahout.math.Matrix; +import org.apache.mahout.math.SparseRowMatrix; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.VectorWritable; +import org.apache.mahout.math.function.DoubleFunction; +import org.apache.mahout.math.stats.Sampler; + +public final class ClusteringTestUtils { + + private ClusteringTestUtils() { + } + + public static void writePointsToFile(Iterable<VectorWritable> points, + Path path, + FileSystem fs, + Configuration conf) throws IOException { + writePointsToFile(points, false, path, fs, conf); + } + + public static void writePointsToFile(Iterable<VectorWritable> points, + boolean intWritable, + Path path, + FileSystem fs, + Configuration conf) throws IOException { + SequenceFile.Writer writer = new SequenceFile.Writer(fs, + conf, + path, + intWritable ? IntWritable.class : LongWritable.class, + VectorWritable.class); + try { + int recNum = 0; + for (VectorWritable point : points) { + writer.append(intWritable ? new IntWritable(recNum++) : new LongWritable(recNum++), point); + } + } finally { + Closeables.close(writer, false); + } + } + + public static Matrix sampledCorpus(Matrix matrix, Random random, + int numDocs, int numSamples, int numTopicsPerDoc) { + Matrix corpus = new SparseRowMatrix(numDocs, matrix.numCols()); + LDASampler modelSampler = new LDASampler(matrix, random); + Vector topicVector = new DenseVector(matrix.numRows()); + for (int i = 0; i < numTopicsPerDoc; i++) { + int topic = random.nextInt(topicVector.size()); + topicVector.set(topic, topicVector.get(topic) + 1); + } + for (int docId = 0; docId < numDocs; docId++) { + for (int sample : modelSampler.sample(topicVector, numSamples)) { + corpus.set(docId, sample, corpus.get(docId, sample) + 1); + } + } + return corpus; + } + + public static Matrix randomStructuredModel(int numTopics, int numTerms) { + return randomStructuredModel(numTopics, numTerms, new DoubleFunction() { + @Override public double apply(double d) { + return 1.0 / (1 + Math.abs(d)); + } + }); + } + + public static Matrix randomStructuredModel(int numTopics, int numTerms, DoubleFunction decay) { + Matrix model = new DenseMatrix(numTopics, numTerms); + int width = numTerms / numTopics; + for (int topic = 0; topic < numTopics; topic++) { + int topicCentroid = width * (1+topic); + for (int i = 0; i < numTerms; i++) { + int distance = Math.abs(topicCentroid - i); + if (distance > numTerms / 2) { + distance = numTerms - distance; + } + double v = decay.apply(distance); + model.set(topic, i, v); + } + } + return model; + } + + /** + * Takes in a {@link Matrix} of topic distributions (such as generated by {@link org.apache.mahout.clustering.lda.cvb.CVB0Driver} or + * {@link org.apache.mahout.clustering.lda.cvb.InMemoryCollapsedVariationalBayes0}, and constructs + * a set of samplers over this distribution, which may be sampled from by providing a distribution + * over topics, and a number of samples desired + */ + static class LDASampler { + private final Random random; + private final Sampler[] samplers; + + LDASampler(Matrix model, Random random) { + this.random = random; + samplers = new Sampler[model.numRows()]; + for (int i = 0; i < samplers.length; i++) { + samplers[i] = new Sampler(random, model.viewRow(i)); + } + } + + /** + * + * @param topicDistribution vector of p(topicId) for all topicId < model.numTopics() + * @param numSamples the number of times to sample (with replacement) from the model + * @return array of length numSamples, with each entry being a sample from the model. There + * may be repeats + */ + public int[] sample(Vector topicDistribution, int numSamples) { + Preconditions.checkNotNull(topicDistribution); + Preconditions.checkArgument(numSamples > 0, "numSamples must be positive"); + Preconditions.checkArgument(topicDistribution.size() == samplers.length, + "topicDistribution must have same cardinality as the sampling model"); + int[] samples = new int[numSamples]; + Sampler topicSampler = new Sampler(random, topicDistribution); + for (int i = 0; i < numSamples; i++) { + samples[i] = samplers[topicSampler.sample()].sample(); + } + return samples; + } + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/clustering/TestClusterInterface.java ---------------------------------------------------------------------- diff --git a/mr/src/test/java/org/apache/mahout/clustering/TestClusterInterface.java b/mr/src/test/java/org/apache/mahout/clustering/TestClusterInterface.java new file mode 100644 index 0000000..1cbfb02 --- /dev/null +++ b/mr/src/test/java/org/apache/mahout/clustering/TestClusterInterface.java @@ -0,0 +1,83 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.clustering; + +import org.apache.mahout.common.MahoutTestCase; +import org.apache.mahout.common.distance.DistanceMeasure; +import org.apache.mahout.common.distance.ManhattanDistanceMeasure; +import org.apache.mahout.math.DenseVector; +import org.apache.mahout.math.SequentialAccessSparseVector; +import org.apache.mahout.math.Vector; +import org.junit.Test; + +public final class TestClusterInterface extends MahoutTestCase { + + private static final DistanceMeasure measure = new ManhattanDistanceMeasure(); + + @Test + public void testClusterAsFormatString() { + double[] d = { 1.1, 2.2, 3.3 }; + Vector m = new DenseVector(d); + Cluster cluster = new org.apache.mahout.clustering.kmeans.Kluster(m, 123, measure); + String formatString = cluster.asFormatString(null); + assertTrue(formatString.contains("\"r\":[]")); + assertTrue(formatString.contains("\"c\":[1.1,2.2,3.3]")); + assertTrue(formatString.contains("\"n\":0")); + assertTrue(formatString.contains("\"identifier\":\"CL-123\"")); + } + + @Test + public void testClusterAsFormatStringSparse() { + double[] d = { 1.1, 0.0, 3.3 }; + Vector m = new SequentialAccessSparseVector(3); + m.assign(d); + Cluster cluster = new org.apache.mahout.clustering.kmeans.Kluster(m, 123, measure); + String formatString = cluster.asFormatString(null); + assertTrue(formatString.contains("\"r\":[]")); + assertTrue(formatString.contains("\"c\":[{\"0\":1.1},{\"2\":3.3}]")); + assertTrue(formatString.contains("\"n\":0")); + assertTrue(formatString.contains("\"identifier\":\"CL-123\"")); + } + + @Test + public void testClusterAsFormatStringWithBindings() { + double[] d = { 1.1, 2.2, 3.3 }; + Vector m = new DenseVector(d); + Cluster cluster = new org.apache.mahout.clustering.kmeans.Kluster(m, 123, measure); + String[] bindings = { "fee", null, "foo" }; + String formatString = cluster.asFormatString(bindings); + assertTrue(formatString.contains("\"r\":[]")); + assertTrue(formatString.contains("\"c\":[{\"fee\":1.1},{\"1\":2.2},{\"foo\":3.3}]")); + assertTrue(formatString.contains("\"n\":0")); + assertTrue(formatString.contains("\"identifier\":\"CL-123\"")); + } + + @Test + public void testClusterAsFormatStringSparseWithBindings() { + double[] d = { 1.1, 0.0, 3.3 }; + Vector m = new SequentialAccessSparseVector(3); + m.assign(d); + Cluster cluster = new org.apache.mahout.clustering.kmeans.Kluster(m, 123, measure); + String formatString = cluster.asFormatString(null); + assertTrue(formatString.contains("\"r\":[]")); + assertTrue(formatString.contains("\"c\":[{\"0\":1.1},{\"2\":3.3}]")); + assertTrue(formatString.contains("\"n\":0")); + assertTrue(formatString.contains("\"identifier\":\"CL-123\"")); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/test/java/org/apache/mahout/clustering/TestGaussianAccumulators.java ---------------------------------------------------------------------- diff --git a/mr/src/test/java/org/apache/mahout/clustering/TestGaussianAccumulators.java b/mr/src/test/java/org/apache/mahout/clustering/TestGaussianAccumulators.java new file mode 100644 index 0000000..43417fc --- /dev/null +++ b/mr/src/test/java/org/apache/mahout/clustering/TestGaussianAccumulators.java @@ -0,0 +1,186 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.mahout.clustering; + +import java.util.Collection; + +import com.google.common.collect.Lists; +import org.apache.mahout.common.MahoutTestCase; +import org.apache.mahout.math.DenseVector; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.VectorWritable; +import org.apache.mahout.math.function.Functions; +import org.apache.mahout.math.function.SquareRootFunction; +import org.junit.Before; +import org.junit.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public final class TestGaussianAccumulators extends MahoutTestCase { + + private static final Logger log = LoggerFactory.getLogger(TestGaussianAccumulators.class); + + private Collection<VectorWritable> sampleData = Lists.newArrayList(); + private int sampleN; + private Vector sampleMean; + private Vector sampleStd; + + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + sampleData = Lists.newArrayList(); + generateSamples(); + sampleN = 0; + Vector sum = new DenseVector(2); + for (VectorWritable v : sampleData) { + sum.assign(v.get(), Functions.PLUS); + sampleN++; + } + sampleMean = sum.divide(sampleN); + + Vector sampleVar = new DenseVector(2); + for (VectorWritable v : sampleData) { + Vector delta = v.get().minus(sampleMean); + sampleVar.assign(delta.times(delta), Functions.PLUS); + } + sampleVar = sampleVar.divide(sampleN - 1); + sampleStd = sampleVar.clone(); + sampleStd.assign(new SquareRootFunction()); + log.info("Observing {} samples m=[{}, {}] sd=[{}, {}]", + sampleN, sampleMean.get(0), sampleMean.get(1), sampleStd.get(0), sampleStd.get(1)); + } + + /** + * Generate random samples and add them to the sampleData + * + * @param num + * int number of samples to generate + * @param mx + * double x-value of the sample mean + * @param my + * double y-value of the sample mean + * @param sdx + * double x-value standard deviation of the samples + * @param sdy + * double y-value standard deviation of the samples + */ + private void generate2dSamples(int num, double mx, double my, double sdx, double sdy) { + log.info("Generating {} samples m=[{}, {}] sd=[{}, {}]", num, mx, my, sdx, sdy); + for (int i = 0; i < num; i++) { + sampleData.add(new VectorWritable(new DenseVector(new double[] { UncommonDistributions.rNorm(mx, sdx), + UncommonDistributions.rNorm(my, sdy) }))); + } + } + + private void generateSamples() { + generate2dSamples(50000, 1, 2, 3, 4); + } + + @Test + public void testAccumulatorNoSamples() { + GaussianAccumulator accumulator0 = new RunningSumsGaussianAccumulator(); + GaussianAccumulator accumulator1 = new OnlineGaussianAccumulator(); + accumulator0.compute(); + accumulator1.compute(); + assertEquals("N", accumulator0.getN(), accumulator1.getN(), EPSILON); + assertEquals("Means", accumulator0.getMean(), accumulator1.getMean()); + assertEquals("Avg Stds", accumulator0.getAverageStd(), accumulator1.getAverageStd(), EPSILON); + } + + @Test + public void testAccumulatorOneSample() { + GaussianAccumulator accumulator0 = new RunningSumsGaussianAccumulator(); + GaussianAccumulator accumulator1 = new OnlineGaussianAccumulator(); + Vector sample = new DenseVector(2); + accumulator0.observe(sample, 1.0); + accumulator1.observe(sample, 1.0); + accumulator0.compute(); + accumulator1.compute(); + assertEquals("N", accumulator0.getN(), accumulator1.getN(), EPSILON); + assertEquals("Means", accumulator0.getMean(), accumulator1.getMean()); + assertEquals("Avg Stds", accumulator0.getAverageStd(), accumulator1.getAverageStd(), EPSILON); + } + + @Test + public void testOLAccumulatorResults() { + GaussianAccumulator accumulator = new OnlineGaussianAccumulator(); + for (VectorWritable vw : sampleData) { + accumulator.observe(vw.get(), 1.0); + } + accumulator.compute(); + log.info("OL Observed {} samples m=[{}, {}] sd=[{}, {}]", + accumulator.getN(), + accumulator.getMean().get(0), + accumulator.getMean().get(1), + accumulator.getStd().get(0), + accumulator.getStd().get(1)); + assertEquals("OL N", sampleN, accumulator.getN(), EPSILON); + assertEquals("OL Mean", sampleMean.zSum(), accumulator.getMean().zSum(), EPSILON); + assertEquals("OL Std", sampleStd.zSum(), accumulator.getStd().zSum(), EPSILON); + } + + @Test + public void testRSAccumulatorResults() { + GaussianAccumulator accumulator = new RunningSumsGaussianAccumulator(); + for (VectorWritable vw : sampleData) { + accumulator.observe(vw.get(), 1.0); + } + accumulator.compute(); + log.info("RS Observed {} samples m=[{}, {}] sd=[{}, {}]", + (int) accumulator.getN(), + accumulator.getMean().get(0), + accumulator.getMean().get(1), + accumulator.getStd().get(0), + accumulator.getStd().get(1)); + assertEquals("OL N", sampleN, accumulator.getN(), EPSILON); + assertEquals("OL Mean", sampleMean.zSum(), accumulator.getMean().zSum(), EPSILON); + assertEquals("OL Std", sampleStd.zSum(), accumulator.getStd().zSum(), 0.0001); + } + + @Test + public void testAccumulatorWeightedResults() { + GaussianAccumulator accumulator0 = new RunningSumsGaussianAccumulator(); + GaussianAccumulator accumulator1 = new OnlineGaussianAccumulator(); + for (VectorWritable vw : sampleData) { + accumulator0.observe(vw.get(), 0.5); + accumulator1.observe(vw.get(), 0.5); + } + accumulator0.compute(); + accumulator1.compute(); + assertEquals("N", accumulator0.getN(), accumulator1.getN(), EPSILON); + assertEquals("Means", accumulator0.getMean().zSum(), accumulator1.getMean().zSum(), EPSILON); + assertEquals("Stds", accumulator0.getStd().zSum(), accumulator1.getStd().zSum(), 0.001); + assertEquals("Variance", accumulator0.getVariance().zSum(), accumulator1.getVariance().zSum(), 0.01); + } + + @Test + public void testAccumulatorWeightedResults2() { + GaussianAccumulator accumulator0 = new RunningSumsGaussianAccumulator(); + GaussianAccumulator accumulator1 = new OnlineGaussianAccumulator(); + for (VectorWritable vw : sampleData) { + accumulator0.observe(vw.get(), 1.5); + accumulator1.observe(vw.get(), 1.5); + } + accumulator0.compute(); + accumulator1.compute(); + assertEquals("N", accumulator0.getN(), accumulator1.getN(), EPSILON); + assertEquals("Means", accumulator0.getMean().zSum(), accumulator1.getMean().zSum(), EPSILON); + assertEquals("Stds", accumulator0.getStd().zSum(), accumulator1.getStd().zSum(), 0.001); + assertEquals("Variance", accumulator0.getVariance().zSum(), accumulator1.getVariance().zSum(), 0.01); + } +}
