http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/classifier/sgd/ModelSerializerTest.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/classifier/sgd/ModelSerializerTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/sgd/ModelSerializerTest.java new file mode 100644 index 0000000..2373b9d --- /dev/null +++ b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/sgd/ModelSerializerTest.java @@ -0,0 +1,162 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.sgd; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.IOException; +import java.util.Random; + +import com.carrotsearch.randomizedtesting.annotations.ThreadLeakLingering; +import com.google.common.io.Closeables; +import org.apache.hadoop.io.Writable; +import org.apache.mahout.classifier.OnlineLearner; +import org.apache.mahout.common.MahoutTestCase; +import org.apache.mahout.common.RandomUtils; +import org.apache.mahout.math.DenseVector; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.function.DoubleFunction; +import org.apache.mahout.math.function.Functions; +import org.apache.mahout.math.stats.GlobalOnlineAuc; +import org.apache.mahout.math.stats.OnlineAuc; +import org.junit.Test; + +public final class ModelSerializerTest extends MahoutTestCase { + + private static <T extends Writable> T roundTrip(T m, Class<T> clazz) throws IOException { + ByteArrayOutputStream buf = new ByteArrayOutputStream(1000); + DataOutputStream dos = new DataOutputStream(buf); + try { + PolymorphicWritable.write(dos, m); + } finally { + Closeables.close(dos, false); + } + return PolymorphicWritable.read(new DataInputStream(new ByteArrayInputStream(buf.toByteArray())), clazz); + } + + @Test + public void onlineAucRoundtrip() throws IOException { + RandomUtils.useTestSeed(); + OnlineAuc auc1 = new GlobalOnlineAuc(); + Random gen = RandomUtils.getRandom(); + for (int i = 0; i < 10000; i++) { + auc1.addSample(0, gen.nextGaussian()); + auc1.addSample(1, gen.nextGaussian() + 1); + } + assertEquals(0.76, auc1.auc(), 0.01); + + OnlineAuc auc3 = roundTrip(auc1, OnlineAuc.class); + + assertEquals(auc1.auc(), auc3.auc(), 0); + + for (int i = 0; i < 1000; i++) { + auc1.addSample(0, gen.nextGaussian()); + auc1.addSample(1, gen.nextGaussian() + 1); + + auc3.addSample(0, gen.nextGaussian()); + auc3.addSample(1, gen.nextGaussian() + 1); + } + + assertEquals(auc1.auc(), auc3.auc(), 0.01); + } + + @Test + public void onlineLogisticRegressionRoundTrip() throws IOException { + 
OnlineLogisticRegression olr = new OnlineLogisticRegression(2, 5, new L1()); + train(olr, 100); + OnlineLogisticRegression olr3 = roundTrip(olr, OnlineLogisticRegression.class); + assertEquals(0, olr.getBeta().minus(olr3.getBeta()).aggregate(Functions.MAX, Functions.IDENTITY), 1.0e-6); + + train(olr, 100); + train(olr3, 100); + + assertEquals(0, olr.getBeta().minus(olr3.getBeta()).aggregate(Functions.MAX, Functions.IDENTITY), 1.0e-6); + olr.close(); + olr3.close(); + } + + @Test + public void crossFoldLearnerRoundTrip() throws IOException { + CrossFoldLearner learner = new CrossFoldLearner(5, 2, 5, new L1()); + train(learner, 100); + CrossFoldLearner olr3 = roundTrip(learner, CrossFoldLearner.class); + double auc1 = learner.auc(); + assertTrue(auc1 > 0.85); + assertEquals(auc1, learner.auc(), 1.0e-6); + assertEquals(auc1, olr3.auc(), 1.0e-6); + + train(learner, 100); + train(learner, 100); + train(olr3, 100); + + assertEquals(learner.auc(), learner.auc(), 0.02); + assertEquals(learner.auc(), olr3.auc(), 0.02); + double auc2 = learner.auc(); + assertTrue(auc2 > auc1); + learner.close(); + olr3.close(); + } + + @ThreadLeakLingering(linger = 1000) + @Test + public void adaptiveLogisticRegressionRoundTrip() throws IOException { + AdaptiveLogisticRegression learner = new AdaptiveLogisticRegression(2, 5, new L1()); + learner.setInterval(200); + train(learner, 400); + AdaptiveLogisticRegression olr3 = roundTrip(learner, AdaptiveLogisticRegression.class); + double auc1 = learner.auc(); + assertTrue(auc1 > 0.85); + assertEquals(auc1, learner.auc(), 1.0e-6); + assertEquals(auc1, olr3.auc(), 1.0e-6); + + train(learner, 1000); + train(learner, 1000); + train(olr3, 1000); + + assertEquals(learner.auc(), learner.auc(), 0.005); + assertEquals(learner.auc(), olr3.auc(), 0.005); + double auc2 = learner.auc(); + assertTrue(String.format("%.3f > %.3f", auc2, auc1), auc2 > auc1); + learner.close(); + olr3.close(); + } + + private static void train(OnlineLearner olr, int n) { + Vector beta = new DenseVector(new double[]{1, -1, 0, 0.5, -0.5}); + Random gen = RandomUtils.getRandom(); + for (int i = 0; i < n; i++) { + Vector x = randomVector(gen, 5); + + int target = gen.nextDouble() < beta.dot(x) ? 1 : 0; + olr.train(target, x); + } + } + + private static Vector randomVector(final Random gen, int n) { + Vector x = new DenseVector(n); + x.assign(new DoubleFunction() { + @Override + public double apply(double v) { + return gen.nextGaussian(); + } + }); + return x; + } +}
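For context when reading the tests above: a minimal, self-contained sketch of the same serialization round trip, outside the test harness. It uses only calls that appear in ModelSerializerTest itself (the PolymorphicWritable write/read pair, the OnlineLogisticRegression constructor, train, and classifyScalar); the class name, file name, and toy training vector below are illustrative assumptions, not part of this patch.

import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import org.apache.mahout.classifier.sgd.L1;
import org.apache.mahout.classifier.sgd.OnlineLogisticRegression;
import org.apache.mahout.classifier.sgd.PolymorphicWritable;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;

// Illustrative only -- not part of the patch above. Persists a trained model
// using the same PolymorphicWritable round trip exercised by ModelSerializerTest.
public class ModelRoundTripExample {
  public static void main(String[] args) throws IOException {
    // 2 categories, 5 features, L1 prior -- the same shape as the test above
    OnlineLogisticRegression model = new OnlineLogisticRegression(2, 5, new L1());
    Vector x = new DenseVector(new double[]{1, 0.5, -0.5, 0, 1});
    model.train(1, x);  // a single toy training example

    // write the model, then read it back through the polymorphic mechanism
    try (DataOutputStream out = new DataOutputStream(new FileOutputStream("model.bin"))) {
      PolymorphicWritable.write(out, model);
    }
    OnlineLogisticRegression restored;
    try (DataInputStream in = new DataInputStream(new FileInputStream("model.bin"))) {
      restored = PolymorphicWritable.read(in, OnlineLogisticRegression.class);
    }
    // the restored model scores the instance the same way as the original
    System.out.println(restored.classifyScalar(x));
  }
}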
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/classifier/sgd/OnlineBaseTest.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/classifier/sgd/OnlineBaseTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/sgd/OnlineBaseTest.java new file mode 100644 index 0000000..e0a252c --- /dev/null +++ b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/sgd/OnlineBaseTest.java @@ -0,0 +1,160 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.sgd; + +import com.google.common.base.CharMatcher; +import com.google.common.base.Charsets; +import com.google.common.base.Splitter; +import com.google.common.collect.Lists; +import com.google.common.collect.Maps; +import com.google.common.io.CharStreams; +import com.google.common.io.Resources; +import org.apache.mahout.classifier.AbstractVectorClassifier; +import org.apache.mahout.classifier.OnlineLearner; +import org.apache.mahout.common.MahoutTestCase; +import org.apache.mahout.common.RandomUtils; +import org.apache.mahout.math.DenseMatrix; +import org.apache.mahout.math.DenseVector; +import org.apache.mahout.math.Matrix; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.function.Functions; + +import java.io.IOException; +import java.io.InputStreamReader; +import java.util.List; +import java.util.Map; +import java.util.Random; + +public abstract class OnlineBaseTest extends MahoutTestCase { + + private Matrix input; + + Matrix getInput() { + return input; + } + + Vector readStandardData() throws IOException { + // 60 test samples. First column is constant. Second and third are normally distributed from + // either N([2,2], 1) (rows 0...29) or N([-2,-2], 1) (rows 30...59). The first 30 rows have a + // target variable of 0, the last 30 a target of 1. The remaining columns are random noise.
+ input = readCsv("sgd.csv"); + + // regenerate the target variable + Vector target = new DenseVector(60); + target.assign(0); + target.viewPart(30, 30).assign(1); + return target; + } + + static void train(Matrix input, Vector target, OnlineLearner lr) { + RandomUtils.useTestSeed(); + Random gen = RandomUtils.getRandom(); + + // train on samples in random order (but only one pass) + for (int row : permute(gen, 60)) { + lr.train((int) target.get(row), input.viewRow(row)); + } + lr.close(); + } + + static void test(Matrix input, Vector target, AbstractVectorClassifier lr, + double expected_mean_error, double expected_absolute_error) { + // now test the accuracy + Matrix tmp = lr.classify(input); + // mean(abs(tmp - target)) + double meanAbsoluteError = tmp.viewColumn(0).minus(target).aggregate(Functions.PLUS, Functions.ABS) / 60; + + // max(abs(tmp - target) + double maxAbsoluteError = tmp.viewColumn(0).minus(target).aggregate(Functions.MAX, Functions.ABS); + + System.out.printf("mAE = %.4f, maxAE = %.4f\n", meanAbsoluteError, maxAbsoluteError); + assertEquals(0, meanAbsoluteError , expected_mean_error); + assertEquals(0, maxAbsoluteError, expected_absolute_error); + + // convenience methods should give the same results + Vector v = lr.classifyScalar(input); + assertEquals(0, v.minus(tmp.viewColumn(0)).norm(1), 1.0e-5); + v = lr.classifyFull(input).viewColumn(1); + assertEquals(0, v.minus(tmp.viewColumn(0)).norm(1), 1.0e-4); + } + + /** + * Permute the integers from 0 ... max-1 + * + * @param gen The random number generator to use. + * @param max The number of integers to permute + * @return An array of jumbled integer values + */ + static int[] permute(Random gen, int max) { + int[] permutation = new int[max]; + permutation[0] = 0; + for (int i = 1; i < max; i++) { + int n = gen.nextInt(i + 1); + if (n == i) { + permutation[i] = i; + } else { + permutation[i] = permutation[n]; + permutation[n] = i; + } + } + return permutation; + } + + + /** + * Reads a file containing CSV data. This isn't implemented quite the way you might like for a + * real program, but does the job for reading test data. Most notably, it will only read numbers, + * not quoted strings. + * + * @param resourceName Where to get the data. + * @return A matrix of the results. 
+ * @throws IOException If there is an error reading the data + */ + static Matrix readCsv(String resourceName) throws IOException { + Splitter onCommas = Splitter.on(',').trimResults(CharMatcher.anyOf(" \"")); + + Readable isr = new InputStreamReader(Resources.getResource(resourceName).openStream(), Charsets.UTF_8); + List<String> data = CharStreams.readLines(isr); + String first = data.get(0); + data = data.subList(1, data.size()); + + List<String> values = Lists.newArrayList(onCommas.split(first)); + Matrix r = new DenseMatrix(data.size(), values.size()); + + int column = 0; + Map<String, Integer> labels = Maps.newHashMap(); + for (String value : values) { + labels.put(value, column); + column++; + } + r.setColumnLabelBindings(labels); + + int row = 0; + for (String line : data) { + column = 0; + values = Lists.newArrayList(onCommas.split(line)); + for (String value : values) { + r.set(row, column, Double.parseDouble(value)); + column++; + } + row++; + } + + return r; + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegressionTest.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegressionTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegressionTest.java new file mode 100644 index 0000000..44b7525 --- /dev/null +++ b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/sgd/OnlineLogisticRegressionTest.java @@ -0,0 +1,330 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.classifier.sgd; + +import com.google.common.base.Charsets; +import com.google.common.base.Splitter; +import com.google.common.collect.Iterables; +import com.google.common.collect.Lists; +import com.google.common.io.Resources; +import org.apache.mahout.common.RandomUtils; +import org.apache.mahout.math.DenseVector; +import org.apache.mahout.math.Matrix; +import org.apache.mahout.math.Vector; +import org.apache.mahout.vectorizer.encoders.Dictionary; +import org.junit.Assert; +import org.junit.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.IOException; +import java.lang.reflect.Field; +import java.util.Collections; +import java.util.List; +import java.util.Random; + + +public final class OnlineLogisticRegressionTest extends OnlineBaseTest { + + private static final Logger logger = LoggerFactory.getLogger(OnlineLogisticRegressionTest.class); + + /** + * The CrossFoldLearner is probably the best learner to use for new applications. + * + * @throws IOException If test resources aren't readable. + */ + @Test + public void crossValidation() throws IOException { + Vector target = readStandardData(); + + CrossFoldLearner lr = new CrossFoldLearner(5, 2, 8, new L1()) + .lambda(1 * 1.0e-3) + .learningRate(50); + + + train(getInput(), target, lr); + + System.out.printf("%.2f %.5f\n", lr.auc(), lr.logLikelihood()); + test(getInput(), target, lr, 0.05, 0.3); + + } + + @Test + public void crossValidatedAuc() throws IOException { + RandomUtils.useTestSeed(); + Random gen = RandomUtils.getRandom(); + + Matrix data = readCsv("cancer.csv"); + CrossFoldLearner lr = new CrossFoldLearner(5, 2, 10, new L1()) + .stepOffset(10) + .decayExponent(0.7) + .lambda(1 * 1.0e-3) + .learningRate(5); + int k = 0; + int[] ordering = permute(gen, data.numRows()); + for (int epoch = 0; epoch < 100; epoch++) { + for (int row : ordering) { + lr.train(row, (int) data.get(row, 9), data.viewRow(row)); + System.out.printf("%d,%d,%.3f\n", epoch, k++, lr.auc()); + } + assertEquals(1, lr.auc(), 0.2); + } + assertEquals(1, lr.auc(), 0.1); + } + + /** + * Verifies that a classifier with known coefficients does the right thing. + */ + @Test + public void testClassify() { + OnlineLogisticRegression lr = new OnlineLogisticRegression(3, 2, new L2(1)); + // set up some internal coefficients as if we had learned them + lr.setBeta(0, 0, -1); + lr.setBeta(1, 0, -2); + + // zero vector gives no information. All classes are equal. 
+ Vector v = lr.classify(new DenseVector(new double[]{0, 0})); + assertEquals(1 / 3.0, v.get(0), 1.0e-8); + assertEquals(1 / 3.0, v.get(1), 1.0e-8); + + v = lr.classifyFull(new DenseVector(new double[]{0, 0})); + assertEquals(1.0, v.zSum(), 1.0e-8); + assertEquals(1 / 3.0, v.get(0), 1.0e-8); + assertEquals(1 / 3.0, v.get(1), 1.0e-8); + assertEquals(1 / 3.0, v.get(2), 1.0e-8); + + // weights for second vector component are still zero so all classifications are equally likely + v = lr.classify(new DenseVector(new double[]{0, 1})); + assertEquals(1 / 3.0, v.get(0), 1.0e-3); + assertEquals(1 / 3.0, v.get(1), 1.0e-3); + + v = lr.classifyFull(new DenseVector(new double[]{0, 1})); + assertEquals(1.0, v.zSum(), 1.0e-8); + assertEquals(1 / 3.0, v.get(0), 1.0e-3); + assertEquals(1 / 3.0, v.get(1), 1.0e-3); + assertEquals(1 / 3.0, v.get(2), 1.0e-3); + + // but the weights on the first component are non-zero + v = lr.classify(new DenseVector(new double[]{1, 0})); + assertEquals(Math.exp(-1) / (1 + Math.exp(-1) + Math.exp(-2)), v.get(0), 1.0e-8); + assertEquals(Math.exp(-2) / (1 + Math.exp(-1) + Math.exp(-2)), v.get(1), 1.0e-8); + + v = lr.classifyFull(new DenseVector(new double[]{1, 0})); + assertEquals(1.0, v.zSum(), 1.0e-8); + assertEquals(1 / (1 + Math.exp(-1) + Math.exp(-2)), v.get(0), 1.0e-8); + assertEquals(Math.exp(-1) / (1 + Math.exp(-1) + Math.exp(-2)), v.get(1), 1.0e-8); + assertEquals(Math.exp(-2) / (1 + Math.exp(-1) + Math.exp(-2)), v.get(2), 1.0e-8); + + lr.setBeta(0, 1, 1); + + v = lr.classifyFull(new DenseVector(new double[]{1, 1})); + assertEquals(1.0, v.zSum(), 1.0e-8); + assertEquals(Math.exp(0) / (1 + Math.exp(0) + Math.exp(-2)), v.get(1), 1.0e-3); + assertEquals(Math.exp(-2) / (1 + Math.exp(0) + Math.exp(-2)), v.get(2), 1.0e-3); + assertEquals(1 / (1 + Math.exp(0) + Math.exp(-2)), v.get(0), 1.0e-3); + + lr.setBeta(1, 1, 3); + + v = lr.classifyFull(new DenseVector(new double[]{1, 1})); + assertEquals(1.0, v.zSum(), 1.0e-8); + assertEquals(Math.exp(0) / (1 + Math.exp(0) + Math.exp(1)), v.get(1), 1.0e-8); + assertEquals(Math.exp(1) / (1 + Math.exp(0) + Math.exp(1)), v.get(2), 1.0e-8); + assertEquals(1 / (1 + Math.exp(0) + Math.exp(1)), v.get(0), 1.0e-8); + } + + @Test + public void iris() throws IOException { + // this test trains a 3-way classifier on the famous Iris dataset. + // a similar exercise can be accomplished in R using this code: + // library(nnet) + // correct = rep(0,100) + // for (j in 1:100) { + // i = order(runif(150)) + // train = iris[i[1:100],] + // test = iris[i[101:150],] + // m = multinom(Species ~ Sepal.Length + Sepal.Width + Petal.Length + Petal.Width, train) + // correct[j] = mean(predict(m, newdata=test) == test$Species) + // } + // hist(correct) + // + // Note that depending on the training/test split, performance can be better or worse. 
+ // There is about a 5% chance of getting accuracy < 90% and about 20% chance of getting accuracy + // of 100% + // + // This test uses a deterministic split that is neither outstandingly good nor bad + + + RandomUtils.useTestSeed(); + Splitter onComma = Splitter.on(","); + + // read the data + List<String> raw = Resources.readLines(Resources.getResource("iris.csv"), Charsets.UTF_8); + + // holds features + List<Vector> data = Lists.newArrayList(); + + // holds target variable + List<Integer> target = Lists.newArrayList(); + + // for decoding target values + Dictionary dict = new Dictionary(); + + // for permuting data later + List<Integer> order = Lists.newArrayList(); + + for (String line : raw.subList(1, raw.size())) { + // order gets a list of indexes + order.add(order.size()); + + // parse the predictor variables + Vector v = new DenseVector(5); + v.set(0, 1); + int i = 1; + Iterable<String> values = onComma.split(line); + for (String value : Iterables.limit(values, 4)) { + v.set(i++, Double.parseDouble(value)); + } + data.add(v); + + // and the target + target.add(dict.intern(Iterables.get(values, 4))); + } + + // randomize the order ... original data has each species all together + // note that this randomization is deterministic + Random random = RandomUtils.getRandom(); + Collections.shuffle(order, random); + + // select training and test data + List<Integer> train = order.subList(0, 100); + List<Integer> test = order.subList(100, 150); + logger.warn("Training set = {}", train); + logger.warn("Test set = {}", test); + + // now train many times and collect information on accuracy each time + int[] correct = new int[test.size() + 1]; + for (int run = 0; run < 200; run++) { + OnlineLogisticRegression lr = new OnlineLogisticRegression(3, 5, new L2(1)); + // 30 training passes should converge to > 95% accuracy nearly always but never to 100% + for (int pass = 0; pass < 30; pass++) { + Collections.shuffle(train, random); + for (int k : train) { + lr.train(target.get(k), data.get(k)); + } + } + + // check the accuracy on held out data + int x = 0; + int[] count = new int[3]; + for (Integer k : test) { + int r = lr.classifyFull(data.get(k)).maxValueIndex(); + count[r]++; + x += r == target.get(k) ? 1 : 0; + } + correct[x]++; + } + + // verify we never saw worse than 95% correct, + for (int i = 0; i < Math.floor(0.95 * test.size()); i++) { + assertEquals(String.format("%d trials had unacceptable accuracy of only %.0f%%: ", correct[i], 100.0 * i / test.size()), 0, correct[i]); + } + // nor perfect + assertEquals(String.format("%d trials had unrealistic accuracy of 100%%", correct[test.size() - 1]), 0, correct[test.size()]); + } + + @Test + public void testTrain() throws Exception { + Vector target = readStandardData(); + + + // lambda here needs to be relatively small to avoid swamping the actual signal, but can be + // larger than usual because the data are dense. 
The learning rate doesn't matter too much + // for this example, but should generally be < 1 + // --passes 1 --rate 50 --lambda 0.001 --input sgd-y.csv --features 21 --output model --noBias + // --target y --categories 2 --predictors V2 V3 V4 V5 V6 V7 --types n + OnlineLogisticRegression lr = new OnlineLogisticRegression(2, 8, new L1()) + .lambda(1 * 1.0e-3) + .learningRate(50); + + train(getInput(), target, lr); + test(getInput(), target, lr, 0.05, 0.3); + } + + /** + * Test for Serialization/DeSerialization + * + */ + @Test + public void testSerializationAndDeSerialization() throws Exception { + OnlineLogisticRegression lr = new OnlineLogisticRegression(2, 8, new L1()) + .lambda(1 * 1.0e-3) + .stepOffset(11) + .alpha(0.01) + .learningRate(50) + .decayExponent(-0.02); + + lr.close(); + + byte[] output; + + try (ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream(); + DataOutputStream dataOutputStream = new DataOutputStream(byteArrayOutputStream)) { + PolymorphicWritable.write(dataOutputStream, lr); + output = byteArrayOutputStream.toByteArray(); + } + + OnlineLogisticRegression read; + + try (ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(output); + DataInputStream dataInputStream = new DataInputStream(byteArrayInputStream)) { + read = PolymorphicWritable.read(dataInputStream, OnlineLogisticRegression.class); + } + + //lambda + Assert.assertEquals((1.0e-3), read.getLambda(), 1.0e-7); + + // Reflection to get private variables + //stepOffset + Field stepOffset = lr.getClass().getDeclaredField("stepOffset"); + stepOffset.setAccessible(true); + int stepOffsetVal = (Integer) stepOffset.get(lr); + Assert.assertEquals(11, stepOffsetVal); + + //decayFactor (alpha) + Field decayFactor = lr.getClass().getDeclaredField("decayFactor"); + decayFactor.setAccessible(true); + double decayFactorVal = (Double) decayFactor.get(lr); + Assert.assertEquals(0.01, decayFactorVal, 1.0e-7); + + //learning rate (mu0) + Field mu0 = lr.getClass().getDeclaredField("mu0"); + mu0.setAccessible(true); + double mu0Val = (Double) mu0.get(lr); + Assert.assertEquals(50, mu0Val, 1.0e-7); + + //forgettingExponent (decayExponent) + Field forgettingExponent = lr.getClass().getDeclaredField("forgettingExponent"); + forgettingExponent.setAccessible(true); + double forgettingExponentVal = (Double) forgettingExponent.get(lr); + Assert.assertEquals(-0.02, forgettingExponentVal, 1.0e-7); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/classifier/sgd/PassiveAggressiveTest.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/classifier/sgd/PassiveAggressiveTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/sgd/PassiveAggressiveTest.java new file mode 100644 index 0000000..df97d38 --- /dev/null +++ b/community/mahout-mr/src/test/java/org/apache/mahout/classifier/sgd/PassiveAggressiveTest.java @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.classifier.sgd; + +import org.apache.mahout.math.Vector; +import org.junit.Test; + +import java.io.IOException; + +public final class PassiveAggressiveTest extends OnlineBaseTest { + + @Test + public void testPassiveAggressive() throws IOException { + Vector target = readStandardData(); + PassiveAggressive pa = new PassiveAggressive(2,8).learningRate(0.1); + train(getInput(), target, pa); + test(getInput(), target, pa, 0.11, 0.31); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/clustering/ClusteringTestUtils.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/clustering/ClusteringTestUtils.java b/community/mahout-mr/src/test/java/org/apache/mahout/clustering/ClusteringTestUtils.java new file mode 100644 index 0000000..62e10c6 --- /dev/null +++ b/community/mahout-mr/src/test/java/org/apache/mahout/clustering/ClusteringTestUtils.java @@ -0,0 +1,152 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.clustering; + +import java.io.IOException; +import java.util.Random; + +import com.google.common.base.Preconditions; +import com.google.common.io.Closeables; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.LongWritable; +import org.apache.hadoop.io.SequenceFile; +import org.apache.mahout.math.DenseMatrix; +import org.apache.mahout.math.DenseVector; +import org.apache.mahout.math.Matrix; +import org.apache.mahout.math.SparseRowMatrix; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.VectorWritable; +import org.apache.mahout.math.function.DoubleFunction; +import org.apache.mahout.math.stats.Sampler; + +public final class ClusteringTestUtils { + + private ClusteringTestUtils() { + } + + public static void writePointsToFile(Iterable<VectorWritable> points, + Path path, + FileSystem fs, + Configuration conf) throws IOException { + writePointsToFile(points, false, path, fs, conf); + } + + public static void writePointsToFile(Iterable<VectorWritable> points, + boolean intWritable, + Path path, + FileSystem fs, + Configuration conf) throws IOException { + SequenceFile.Writer writer = new SequenceFile.Writer(fs, + conf, + path, + intWritable ? IntWritable.class : LongWritable.class, + VectorWritable.class); + try { + int recNum = 0; + for (VectorWritable point : points) { + writer.append(intWritable ? new IntWritable(recNum++) : new LongWritable(recNum++), point); + } + } finally { + Closeables.close(writer, false); + } + } + + public static Matrix sampledCorpus(Matrix matrix, Random random, + int numDocs, int numSamples, int numTopicsPerDoc) { + Matrix corpus = new SparseRowMatrix(numDocs, matrix.numCols()); + LDASampler modelSampler = new LDASampler(matrix, random); + Vector topicVector = new DenseVector(matrix.numRows()); + for (int i = 0; i < numTopicsPerDoc; i++) { + int topic = random.nextInt(topicVector.size()); + topicVector.set(topic, topicVector.get(topic) + 1); + } + for (int docId = 0; docId < numDocs; docId++) { + for (int sample : modelSampler.sample(topicVector, numSamples)) { + corpus.set(docId, sample, corpus.get(docId, sample) + 1); + } + } + return corpus; + } + + public static Matrix randomStructuredModel(int numTopics, int numTerms) { + return randomStructuredModel(numTopics, numTerms, new DoubleFunction() { + @Override public double apply(double d) { + return 1.0 / (1 + Math.abs(d)); + } + }); + } + + public static Matrix randomStructuredModel(int numTopics, int numTerms, DoubleFunction decay) { + Matrix model = new DenseMatrix(numTopics, numTerms); + int width = numTerms / numTopics; + for (int topic = 0; topic < numTopics; topic++) { + int topicCentroid = width * (1+topic); + for (int i = 0; i < numTerms; i++) { + int distance = Math.abs(topicCentroid - i); + if (distance > numTerms / 2) { + distance = numTerms - distance; + } + double v = decay.apply(distance); + model.set(topic, i, v); + } + } + return model; + } + + /** + * Takes in a {@link Matrix} of topic distributions (such as generated by {@link org.apache.mahout.clustering.lda.cvb.CVB0Driver} or + * {@link org.apache.mahout.clustering.lda.cvb.InMemoryCollapsedVariationalBayes0}, and constructs + * a set of samplers over this distribution, which may be sampled from by providing a distribution + * over topics, and a number of samples desired + */ + static class LDASampler { + private final Random random; + private final 
Sampler[] samplers; + + LDASampler(Matrix model, Random random) { + this.random = random; + samplers = new Sampler[model.numRows()]; + for (int i = 0; i < samplers.length; i++) { + samplers[i] = new Sampler(random, model.viewRow(i)); + } + } + + /** + * + * @param topicDistribution vector of p(topicId) for all topicId < model.numTopics() + * @param numSamples the number of times to sample (with replacement) from the model + * @return array of length numSamples, with each entry being a sample from the model. There + * may be repeats + */ + public int[] sample(Vector topicDistribution, int numSamples) { + Preconditions.checkNotNull(topicDistribution); + Preconditions.checkArgument(numSamples > 0, "numSamples must be positive"); + Preconditions.checkArgument(topicDistribution.size() == samplers.length, + "topicDistribution must have same cardinality as the sampling model"); + int[] samples = new int[numSamples]; + Sampler topicSampler = new Sampler(random, topicDistribution); + for (int i = 0; i < numSamples; i++) { + samples[i] = samplers[topicSampler.sample()].sample(); + } + return samples; + } + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/clustering/TestClusterInterface.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/clustering/TestClusterInterface.java b/community/mahout-mr/src/test/java/org/apache/mahout/clustering/TestClusterInterface.java new file mode 100644 index 0000000..1cbfb02 --- /dev/null +++ b/community/mahout-mr/src/test/java/org/apache/mahout/clustering/TestClusterInterface.java @@ -0,0 +1,83 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.clustering; + +import org.apache.mahout.common.MahoutTestCase; +import org.apache.mahout.common.distance.DistanceMeasure; +import org.apache.mahout.common.distance.ManhattanDistanceMeasure; +import org.apache.mahout.math.DenseVector; +import org.apache.mahout.math.SequentialAccessSparseVector; +import org.apache.mahout.math.Vector; +import org.junit.Test; + +public final class TestClusterInterface extends MahoutTestCase { + + private static final DistanceMeasure measure = new ManhattanDistanceMeasure(); + + @Test + public void testClusterAsFormatString() { + double[] d = { 1.1, 2.2, 3.3 }; + Vector m = new DenseVector(d); + Cluster cluster = new org.apache.mahout.clustering.kmeans.Kluster(m, 123, measure); + String formatString = cluster.asFormatString(null); + assertTrue(formatString.contains("\"r\":[]")); + assertTrue(formatString.contains("\"c\":[1.1,2.2,3.3]")); + assertTrue(formatString.contains("\"n\":0")); + assertTrue(formatString.contains("\"identifier\":\"CL-123\"")); + } + + @Test + public void testClusterAsFormatStringSparse() { + double[] d = { 1.1, 0.0, 3.3 }; + Vector m = new SequentialAccessSparseVector(3); + m.assign(d); + Cluster cluster = new org.apache.mahout.clustering.kmeans.Kluster(m, 123, measure); + String formatString = cluster.asFormatString(null); + assertTrue(formatString.contains("\"r\":[]")); + assertTrue(formatString.contains("\"c\":[{\"0\":1.1},{\"2\":3.3}]")); + assertTrue(formatString.contains("\"n\":0")); + assertTrue(formatString.contains("\"identifier\":\"CL-123\"")); + } + + @Test + public void testClusterAsFormatStringWithBindings() { + double[] d = { 1.1, 2.2, 3.3 }; + Vector m = new DenseVector(d); + Cluster cluster = new org.apache.mahout.clustering.kmeans.Kluster(m, 123, measure); + String[] bindings = { "fee", null, "foo" }; + String formatString = cluster.asFormatString(bindings); + assertTrue(formatString.contains("\"r\":[]")); + assertTrue(formatString.contains("\"c\":[{\"fee\":1.1},{\"1\":2.2},{\"foo\":3.3}]")); + assertTrue(formatString.contains("\"n\":0")); + assertTrue(formatString.contains("\"identifier\":\"CL-123\"")); + } + + @Test + public void testClusterAsFormatStringSparseWithBindings() { + double[] d = { 1.1, 0.0, 3.3 }; + Vector m = new SequentialAccessSparseVector(3); + m.assign(d); + Cluster cluster = new org.apache.mahout.clustering.kmeans.Kluster(m, 123, measure); + String formatString = cluster.asFormatString(null); + assertTrue(formatString.contains("\"r\":[]")); + assertTrue(formatString.contains("\"c\":[{\"0\":1.1},{\"2\":3.3}]")); + assertTrue(formatString.contains("\"n\":0")); + assertTrue(formatString.contains("\"identifier\":\"CL-123\"")); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/clustering/TestGaussianAccumulators.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/clustering/TestGaussianAccumulators.java b/community/mahout-mr/src/test/java/org/apache/mahout/clustering/TestGaussianAccumulators.java new file mode 100644 index 0000000..43417fc --- /dev/null +++ b/community/mahout-mr/src/test/java/org/apache/mahout/clustering/TestGaussianAccumulators.java @@ -0,0 +1,186 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.mahout.clustering; + +import java.util.Collection; + +import com.google.common.collect.Lists; +import org.apache.mahout.common.MahoutTestCase; +import org.apache.mahout.math.DenseVector; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.VectorWritable; +import org.apache.mahout.math.function.Functions; +import org.apache.mahout.math.function.SquareRootFunction; +import org.junit.Before; +import org.junit.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public final class TestGaussianAccumulators extends MahoutTestCase { + + private static final Logger log = LoggerFactory.getLogger(TestGaussianAccumulators.class); + + private Collection<VectorWritable> sampleData = Lists.newArrayList(); + private int sampleN; + private Vector sampleMean; + private Vector sampleStd; + + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + sampleData = Lists.newArrayList(); + generateSamples(); + sampleN = 0; + Vector sum = new DenseVector(2); + for (VectorWritable v : sampleData) { + sum.assign(v.get(), Functions.PLUS); + sampleN++; + } + sampleMean = sum.divide(sampleN); + + Vector sampleVar = new DenseVector(2); + for (VectorWritable v : sampleData) { + Vector delta = v.get().minus(sampleMean); + sampleVar.assign(delta.times(delta), Functions.PLUS); + } + sampleVar = sampleVar.divide(sampleN - 1); + sampleStd = sampleVar.clone(); + sampleStd.assign(new SquareRootFunction()); + log.info("Observing {} samples m=[{}, {}] sd=[{}, {}]", + sampleN, sampleMean.get(0), sampleMean.get(1), sampleStd.get(0), sampleStd.get(1)); + } + + /** + * Generate random samples and add them to the sampleData + * + * @param num + * int number of samples to generate + * @param mx + * double x-value of the sample mean + * @param my + * double y-value of the sample mean + * @param sdx + * double x-value standard deviation of the samples + * @param sdy + * double y-value standard deviation of the samples + */ + private void generate2dSamples(int num, double mx, double my, double sdx, double sdy) { + log.info("Generating {} samples m=[{}, {}] sd=[{}, {}]", num, mx, my, sdx, sdy); + for (int i = 0; i < num; i++) { + sampleData.add(new VectorWritable(new DenseVector(new double[] { UncommonDistributions.rNorm(mx, sdx), + UncommonDistributions.rNorm(my, sdy) }))); + } + } + + private void generateSamples() { + generate2dSamples(50000, 1, 2, 3, 4); + } + + @Test + public void testAccumulatorNoSamples() { + GaussianAccumulator accumulator0 = new RunningSumsGaussianAccumulator(); + GaussianAccumulator accumulator1 = new OnlineGaussianAccumulator(); + accumulator0.compute(); + accumulator1.compute(); + assertEquals("N", accumulator0.getN(), accumulator1.getN(), EPSILON); + assertEquals("Means", accumulator0.getMean(), accumulator1.getMean()); + assertEquals("Avg Stds", accumulator0.getAverageStd(), accumulator1.getAverageStd(), EPSILON); + } + + @Test + public void 
testAccumulatorOneSample() { + GaussianAccumulator accumulator0 = new RunningSumsGaussianAccumulator(); + GaussianAccumulator accumulator1 = new OnlineGaussianAccumulator(); + Vector sample = new DenseVector(2); + accumulator0.observe(sample, 1.0); + accumulator1.observe(sample, 1.0); + accumulator0.compute(); + accumulator1.compute(); + assertEquals("N", accumulator0.getN(), accumulator1.getN(), EPSILON); + assertEquals("Means", accumulator0.getMean(), accumulator1.getMean()); + assertEquals("Avg Stds", accumulator0.getAverageStd(), accumulator1.getAverageStd(), EPSILON); + } + + @Test + public void testOLAccumulatorResults() { + GaussianAccumulator accumulator = new OnlineGaussianAccumulator(); + for (VectorWritable vw : sampleData) { + accumulator.observe(vw.get(), 1.0); + } + accumulator.compute(); + log.info("OL Observed {} samples m=[{}, {}] sd=[{}, {}]", + accumulator.getN(), + accumulator.getMean().get(0), + accumulator.getMean().get(1), + accumulator.getStd().get(0), + accumulator.getStd().get(1)); + assertEquals("OL N", sampleN, accumulator.getN(), EPSILON); + assertEquals("OL Mean", sampleMean.zSum(), accumulator.getMean().zSum(), EPSILON); + assertEquals("OL Std", sampleStd.zSum(), accumulator.getStd().zSum(), EPSILON); + } + + @Test + public void testRSAccumulatorResults() { + GaussianAccumulator accumulator = new RunningSumsGaussianAccumulator(); + for (VectorWritable vw : sampleData) { + accumulator.observe(vw.get(), 1.0); + } + accumulator.compute(); + log.info("RS Observed {} samples m=[{}, {}] sd=[{}, {}]", + (int) accumulator.getN(), + accumulator.getMean().get(0), + accumulator.getMean().get(1), + accumulator.getStd().get(0), + accumulator.getStd().get(1)); + assertEquals("OL N", sampleN, accumulator.getN(), EPSILON); + assertEquals("OL Mean", sampleMean.zSum(), accumulator.getMean().zSum(), EPSILON); + assertEquals("OL Std", sampleStd.zSum(), accumulator.getStd().zSum(), 0.0001); + } + + @Test + public void testAccumulatorWeightedResults() { + GaussianAccumulator accumulator0 = new RunningSumsGaussianAccumulator(); + GaussianAccumulator accumulator1 = new OnlineGaussianAccumulator(); + for (VectorWritable vw : sampleData) { + accumulator0.observe(vw.get(), 0.5); + accumulator1.observe(vw.get(), 0.5); + } + accumulator0.compute(); + accumulator1.compute(); + assertEquals("N", accumulator0.getN(), accumulator1.getN(), EPSILON); + assertEquals("Means", accumulator0.getMean().zSum(), accumulator1.getMean().zSum(), EPSILON); + assertEquals("Stds", accumulator0.getStd().zSum(), accumulator1.getStd().zSum(), 0.001); + assertEquals("Variance", accumulator0.getVariance().zSum(), accumulator1.getVariance().zSum(), 0.01); + } + + @Test + public void testAccumulatorWeightedResults2() { + GaussianAccumulator accumulator0 = new RunningSumsGaussianAccumulator(); + GaussianAccumulator accumulator1 = new OnlineGaussianAccumulator(); + for (VectorWritable vw : sampleData) { + accumulator0.observe(vw.get(), 1.5); + accumulator1.observe(vw.get(), 1.5); + } + accumulator0.compute(); + accumulator1.compute(); + assertEquals("N", accumulator0.getN(), accumulator1.getN(), EPSILON); + assertEquals("Means", accumulator0.getMean().zSum(), accumulator1.getMean().zSum(), EPSILON); + assertEquals("Stds", accumulator0.getStd().zSum(), accumulator1.getStd().zSum(), 0.001); + assertEquals("Variance", accumulator0.getVariance().zSum(), accumulator1.getVariance().zSum(), 0.01); + } +} 
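A short usage sketch of the accumulator lifecycle the tests above exercise: observe weighted points, call compute() once, then read the moments. It relies only on the GaussianAccumulator methods used in TestGaussianAccumulators (observe, compute, getN, getMean, getStd); the class name, seed, sample counts, and target distribution below are illustrative assumptions, not part of this patch.

import java.util.Random;
import org.apache.mahout.clustering.GaussianAccumulator;
import org.apache.mahout.clustering.OnlineGaussianAccumulator;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;

// Illustrative only -- not part of the patch above. Estimates the mean and
// standard deviation of a stream of 2-d points with an OnlineGaussianAccumulator.
public class AccumulatorExample {
  public static void main(String[] args) {
    GaussianAccumulator acc = new OnlineGaussianAccumulator();
    Random gen = new Random(42);
    // observe 10,000 points drawn from N([1, 2], 1), each with unit weight
    for (int i = 0; i < 10000; i++) {
      Vector point = new DenseVector(new double[]{1 + gen.nextGaussian(), 2 + gen.nextGaussian()});
      acc.observe(point, 1.0);
    }
    acc.compute();  // finalize the running estimates before reading them
    System.out.println("n    = " + acc.getN());
    System.out.println("mean = " + acc.getMean());  // expect roughly [1, 2]
    System.out.println("std  = " + acc.getStd());   // expect roughly [1, 1]
  }
}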
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/clustering/canopy/TestCanopyCreation.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/clustering/canopy/TestCanopyCreation.java b/community/mahout-mr/src/test/java/org/apache/mahout/clustering/canopy/TestCanopyCreation.java new file mode 100644 index 0000000..097fd74 --- /dev/null +++ b/community/mahout-mr/src/test/java/org/apache/mahout/clustering/canopy/TestCanopyCreation.java @@ -0,0 +1,674 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.clustering.canopy; + +import java.util.Collection; +import java.util.List; +import java.util.Set; + +import com.google.common.collect.Iterables; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.SequenceFile; +import org.apache.hadoop.io.Text; +import org.apache.hadoop.io.Writable; +import org.apache.hadoop.io.WritableComparable; +import org.apache.hadoop.mapreduce.Mapper; +import org.apache.hadoop.mapreduce.Reducer; +import org.apache.hadoop.util.ToolRunner; +import org.apache.mahout.clustering.ClusteringTestUtils; +import org.apache.mahout.clustering.iterator.ClusterWritable; +import org.apache.mahout.common.DummyRecordWriter; +import org.apache.mahout.common.HadoopUtil; +import org.apache.mahout.common.MahoutTestCase; +import org.apache.mahout.common.Pair; +import org.apache.mahout.common.commandline.DefaultOptionCreator; +import org.apache.mahout.common.distance.DistanceMeasure; +import org.apache.mahout.common.distance.EuclideanDistanceMeasure; +import org.apache.mahout.common.distance.ManhattanDistanceMeasure; +import org.apache.mahout.common.iterator.sequencefile.SequenceFileValueIterable; +import org.apache.mahout.math.RandomAccessSparseVector; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.VectorWritable; +import org.junit.Before; +import org.junit.Test; + +import com.google.common.collect.Lists; +import com.google.common.io.Closeables; + +@Deprecated +public final class TestCanopyCreation extends MahoutTestCase { + + private static final double[][] RAW = { { 1, 1 }, { 2, 1 }, { 1, 2 }, + { 2, 2 }, { 3, 3 }, { 4, 4 }, { 5, 4 }, { 4, 5 }, { 5, 5 } }; + + private List<Canopy> referenceManhattan; + + private final DistanceMeasure manhattanDistanceMeasure = new ManhattanDistanceMeasure(); + + private List<Vector> manhattanCentroids; + + private List<Canopy> referenceEuclidean; + + private final DistanceMeasure euclideanDistanceMeasure = new EuclideanDistanceMeasure(); + + private List<Vector> euclideanCentroids; + + private FileSystem fs; + + 
private static List<VectorWritable> getPointsWritable() { + List<VectorWritable> points = Lists.newArrayList(); + for (double[] fr : RAW) { + Vector vec = new RandomAccessSparseVector(fr.length); + vec.assign(fr); + points.add(new VectorWritable(vec)); + } + return points; + } + + private static List<Vector> getPoints() { + List<Vector> points = Lists.newArrayList(); + for (double[] fr : RAW) { + Vector vec = new RandomAccessSparseVector(fr.length); + vec.assign(fr); + points.add(vec); + } + return points; + } + + /** + * Print the canopies to the transcript + * + * @param canopies + * a List<Canopy> + */ + private static void printCanopies(Iterable<Canopy> canopies) { + for (Canopy canopy : canopies) { + System.out.println(canopy.asFormatString(null)); + } + } + + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + fs = FileSystem.get(getConfiguration()); + referenceManhattan = CanopyClusterer.createCanopies(getPoints(), + manhattanDistanceMeasure, 3.1, 2.1); + manhattanCentroids = CanopyClusterer.getCenters(referenceManhattan); + referenceEuclidean = CanopyClusterer.createCanopies(getPoints(), + euclideanDistanceMeasure, 3.1, 2.1); + euclideanCentroids = CanopyClusterer.getCenters(referenceEuclidean); + } + + /** + * Story: User can cluster points using a ManhattanDistanceMeasure and a + * reference implementation + */ + @Test + public void testReferenceManhattan() throws Exception { + // see setUp for cluster creation + printCanopies(referenceManhattan); + assertEquals("number of canopies", 3, referenceManhattan.size()); + for (int canopyIx = 0; canopyIx < referenceManhattan.size(); canopyIx++) { + Canopy testCanopy = referenceManhattan.get(canopyIx); + int[] expectedNumPoints = { 4, 4, 3 }; + double[][] expectedCentroids = { { 1.5, 1.5 }, { 4.0, 4.0 }, + { 4.666666666666667, 4.6666666666666667 } }; + assertEquals("canopy points " + canopyIx, testCanopy.getNumObservations(), + expectedNumPoints[canopyIx]); + double[] refCentroid = expectedCentroids[canopyIx]; + Vector testCentroid = testCanopy.computeCentroid(); + for (int pointIx = 0; pointIx < refCentroid.length; pointIx++) { + assertEquals("canopy centroid " + canopyIx + '[' + pointIx + ']', + refCentroid[pointIx], testCentroid.get(pointIx), EPSILON); + } + } + } + + /** + * Story: User can cluster points using a EuclideanDistanceMeasure and a + * reference implementation + */ + @Test + public void testReferenceEuclidean() throws Exception { + // see setUp for cluster creation + printCanopies(referenceEuclidean); + assertEquals("number of canopies", 3, referenceEuclidean.size()); + int[] expectedNumPoints = { 5, 5, 3 }; + double[][] expectedCentroids = { { 1.8, 1.8 }, { 4.2, 4.2 }, + { 4.666666666666667, 4.666666666666667 } }; + for (int canopyIx = 0; canopyIx < referenceEuclidean.size(); canopyIx++) { + Canopy testCanopy = referenceEuclidean.get(canopyIx); + assertEquals("canopy points " + canopyIx, testCanopy.getNumObservations(), + expectedNumPoints[canopyIx]); + double[] refCentroid = expectedCentroids[canopyIx]; + Vector testCentroid = testCanopy.computeCentroid(); + for (int pointIx = 0; pointIx < refCentroid.length; pointIx++) { + assertEquals("canopy centroid " + canopyIx + '[' + pointIx + ']', + refCentroid[pointIx], testCentroid.get(pointIx), EPSILON); + } + } + } + + /** + * Story: User can produce initial canopy centers using a + * ManhattanDistanceMeasure and a CanopyMapper which clusters input points to + * produce an output set of canopy centroid points. 
+ */ + @Test + public void testCanopyMapperManhattan() throws Exception { + CanopyMapper mapper = new CanopyMapper(); + Configuration conf = getConfiguration(); + conf.set(CanopyConfigKeys.DISTANCE_MEASURE_KEY, manhattanDistanceMeasure + .getClass().getName()); + conf.set(CanopyConfigKeys.T1_KEY, String.valueOf(3.1)); + conf.set(CanopyConfigKeys.T2_KEY, String.valueOf(2.1)); + conf.set(CanopyConfigKeys.CF_KEY, "0"); + DummyRecordWriter<Text, VectorWritable> writer = new DummyRecordWriter<>(); + Mapper<WritableComparable<?>, VectorWritable, Text, VectorWritable>.Context context = DummyRecordWriter + .build(mapper, conf, writer); + mapper.setup(context); + + List<VectorWritable> points = getPointsWritable(); + // map the data + for (VectorWritable point : points) { + mapper.map(new Text(), point, context); + } + mapper.cleanup(context); + assertEquals("Number of map results", 1, writer.getData().size()); + // now verify the output + List<VectorWritable> data = writer.getValue(new Text("centroid")); + assertEquals("Number of centroids", 3, data.size()); + for (int i = 0; i < data.size(); i++) { + assertEquals("Centroid error", + manhattanCentroids.get(i).asFormatString(), data.get(i).get() + .asFormatString()); + } + } + + /** + * Story: User can produce initial canopy centers using a + * EuclideanDistanceMeasure and a CanopyMapper/Combiner which clusters input + * points to produce an output set of canopy centroid points. + */ + @Test + public void testCanopyMapperEuclidean() throws Exception { + CanopyMapper mapper = new CanopyMapper(); + Configuration conf = getConfiguration(); + conf.set(CanopyConfigKeys.DISTANCE_MEASURE_KEY, euclideanDistanceMeasure + .getClass().getName()); + conf.set(CanopyConfigKeys.T1_KEY, String.valueOf(3.1)); + conf.set(CanopyConfigKeys.T2_KEY, String.valueOf(2.1)); + conf.set(CanopyConfigKeys.CF_KEY, "0"); + DummyRecordWriter<Text, VectorWritable> writer = new DummyRecordWriter<>(); + Mapper<WritableComparable<?>, VectorWritable, Text, VectorWritable>.Context context = DummyRecordWriter + .build(mapper, conf, writer); + mapper.setup(context); + + List<VectorWritable> points = getPointsWritable(); + // map the data + for (VectorWritable point : points) { + mapper.map(new Text(), point, context); + } + mapper.cleanup(context); + assertEquals("Number of map results", 1, writer.getData().size()); + // now verify the output + List<VectorWritable> data = writer.getValue(new Text("centroid")); + assertEquals("Number of centroids", 3, data.size()); + for (int i = 0; i < data.size(); i++) { + assertEquals("Centroid error", + euclideanCentroids.get(i).asFormatString(), data.get(i).get() + .asFormatString()); + } + } + + /** + * Story: User can produce final canopy centers using a + * ManhattanDistanceMeasure and a CanopyReducer which clusters input centroid + * points to produce an output set of final canopy centroid points. 
+ */ + @Test + public void testCanopyReducerManhattan() throws Exception { + CanopyReducer reducer = new CanopyReducer(); + Configuration conf = getConfiguration(); + conf.set(CanopyConfigKeys.DISTANCE_MEASURE_KEY, + "org.apache.mahout.common.distance.ManhattanDistanceMeasure"); + conf.set(CanopyConfigKeys.T1_KEY, String.valueOf(3.1)); + conf.set(CanopyConfigKeys.T2_KEY, String.valueOf(2.1)); + conf.set(CanopyConfigKeys.CF_KEY, "0"); + DummyRecordWriter<Text, ClusterWritable> writer = new DummyRecordWriter<>(); + Reducer<Text, VectorWritable, Text, ClusterWritable>.Context context = DummyRecordWriter + .build(reducer, conf, writer, Text.class, VectorWritable.class); + reducer.setup(context); + + List<VectorWritable> points = getPointsWritable(); + reducer.reduce(new Text("centroid"), points, context); + Iterable<Text> keys = writer.getKeysInInsertionOrder(); + assertEquals("Number of centroids", 3, Iterables.size(keys)); + int i = 0; + for (Text key : keys) { + List<ClusterWritable> data = writer.getValue(key); + ClusterWritable clusterWritable = data.get(0); + Canopy canopy = (Canopy) clusterWritable.getValue(); + assertEquals(manhattanCentroids.get(i).asFormatString() + " is not equal to " + + canopy.computeCentroid().asFormatString(), + manhattanCentroids.get(i), canopy.computeCentroid()); + i++; + } + } + + /** + * Story: User can produce final canopy centers using a + * EuclideanDistanceMeasure and a CanopyReducer which clusters input centroid + * points to produce an output set of final canopy centroid points. + */ + @Test + public void testCanopyReducerEuclidean() throws Exception { + CanopyReducer reducer = new CanopyReducer(); + Configuration conf = getConfiguration(); + conf.set(CanopyConfigKeys.DISTANCE_MEASURE_KEY, "org.apache.mahout.common.distance.EuclideanDistanceMeasure"); + conf.set(CanopyConfigKeys.T1_KEY, String.valueOf(3.1)); + conf.set(CanopyConfigKeys.T2_KEY, String.valueOf(2.1)); + conf.set(CanopyConfigKeys.CF_KEY, "0"); + DummyRecordWriter<Text, ClusterWritable> writer = new DummyRecordWriter<>(); + Reducer<Text, VectorWritable, Text, ClusterWritable>.Context context = + DummyRecordWriter.build(reducer, conf, writer, Text.class, VectorWritable.class); + reducer.setup(context); + + List<VectorWritable> points = getPointsWritable(); + reducer.reduce(new Text("centroid"), points, context); + Iterable<Text> keys = writer.getKeysInInsertionOrder(); + assertEquals("Number of centroids", 3, Iterables.size(keys)); + int i = 0; + for (Text key : keys) { + List<ClusterWritable> data = writer.getValue(key); + ClusterWritable clusterWritable = data.get(0); + Canopy canopy = (Canopy) clusterWritable.getValue(); + assertEquals(euclideanCentroids.get(i).asFormatString() + " is not equal to " + + canopy.computeCentroid().asFormatString(), + euclideanCentroids.get(i), canopy.computeCentroid()); + i++; + } + } + + /** + * Story: User can produce final canopy centers using a Hadoop map/reduce job + * and a ManhattanDistanceMeasure. 
+ */ + @Test + public void testCanopyGenManhattanMR() throws Exception { + List<VectorWritable> points = getPointsWritable(); + Configuration config = getConfiguration(); + ClusteringTestUtils.writePointsToFile(points, + getTestTempFilePath("testdata/file1"), fs, config); + ClusteringTestUtils.writePointsToFile(points, + getTestTempFilePath("testdata/file2"), fs, config); + // now run the Canopy Driver + Path output = getTestTempDirPath("output"); + CanopyDriver.run(config, getTestTempDirPath("testdata"), output, + manhattanDistanceMeasure, 3.1, 2.1, false, 0.0, false); + + // verify output from sequence file + Path path = new Path(output, "clusters-0-final/part-r-00000"); + FileSystem fs = FileSystem.get(path.toUri(), config); + SequenceFile.Reader reader = new SequenceFile.Reader(fs, path, config); + try { + Writable key = new Text(); + ClusterWritable clusterWritable = new ClusterWritable(); + assertTrue("more to come", reader.next(key, clusterWritable)); + assertEquals("1st key", "C-0", key.toString()); + + List<Pair<Double,Double>> refCenters = Lists.newArrayList(); + refCenters.add(new Pair<>(1.5,1.5)); + refCenters.add(new Pair<>(4.333333333333334,4.333333333333334)); + Pair<Double,Double> c = new Pair<>(clusterWritable.getValue() .getCenter().get(0), + clusterWritable.getValue().getCenter().get(1)); + assertTrue("center "+c+" not found", findAndRemove(c, refCenters, EPSILON)); + assertTrue("more to come", reader.next(key, clusterWritable)); + assertEquals("2nd key", "C-1", key.toString()); + c = new Pair<>(clusterWritable.getValue().getCenter().get(0), + clusterWritable.getValue().getCenter().get(1)); + assertTrue("center " + c + " not found", findAndRemove(c, refCenters, EPSILON)); + assertFalse("more to come", reader.next(key, clusterWritable)); + } finally { + Closeables.close(reader, true); + } + } + + static boolean findAndRemove(Pair<Double, Double> target, Collection<Pair<Double, Double>> list, double epsilon) { + for (Pair<Double,Double> curr : list) { + if ( (Math.abs(target.getFirst() - curr.getFirst()) < epsilon) + && (Math.abs(target.getSecond() - curr.getSecond()) < epsilon) ) { + list.remove(curr); + return true; + } + } + return false; + } + + /** + * Story: User can produce final canopy centers using a Hadoop map/reduce job + * and a EuclideanDistanceMeasure. 
+ */ + @Test + public void testCanopyGenEuclideanMR() throws Exception { + List<VectorWritable> points = getPointsWritable(); + Configuration config = getConfiguration(); + ClusteringTestUtils.writePointsToFile(points, + getTestTempFilePath("testdata/file1"), fs, config); + ClusteringTestUtils.writePointsToFile(points, + getTestTempFilePath("testdata/file2"), fs, config); + // now run the Canopy Driver + Path output = getTestTempDirPath("output"); + CanopyDriver.run(config, getTestTempDirPath("testdata"), output, + euclideanDistanceMeasure, 3.1, 2.1, false, 0.0, false); + + // verify output from sequence file + Path path = new Path(output, "clusters-0-final/part-r-00000"); + FileSystem fs = FileSystem.get(path.toUri(), config); + SequenceFile.Reader reader = new SequenceFile.Reader(fs, path, config); + try { + Writable key = new Text(); + ClusterWritable clusterWritable = new ClusterWritable(); + assertTrue("more to come", reader.next(key, clusterWritable)); + assertEquals("1st key", "C-0", key.toString()); + + List<Pair<Double,Double>> refCenters = Lists.newArrayList(); + refCenters.add(new Pair<>(1.8,1.8)); + refCenters.add(new Pair<>(4.433333333333334, 4.433333333333334)); + Pair<Double,Double> c = new Pair<>(clusterWritable.getValue().getCenter().get(0), + clusterWritable.getValue().getCenter().get(1)); + assertTrue("center "+c+" not found", findAndRemove(c, refCenters, EPSILON)); + assertTrue("more to come", reader.next(key, clusterWritable)); + assertEquals("2nd key", "C-1", key.toString()); + c = new Pair<>(clusterWritable.getValue().getCenter().get(0), + clusterWritable.getValue().getCenter().get(1)); + assertTrue("center "+c+" not found", findAndRemove(c, refCenters, EPSILON)); + assertFalse("more to come", reader.next(key, clusterWritable)); + } finally { + Closeables.close(reader, true); + } + } + + /** Story: User can cluster points using sequential execution */ + @Test + public void testClusteringManhattanSeq() throws Exception { + List<VectorWritable> points = getPointsWritable(); + Configuration config = getConfiguration(); + ClusteringTestUtils.writePointsToFile(points, + getTestTempFilePath("testdata/file1"), fs, config); + // now run the Canopy Driver in sequential mode + Path output = getTestTempDirPath("output"); + CanopyDriver.run(config, getTestTempDirPath("testdata"), output, + manhattanDistanceMeasure, 3.1, 2.1, true, 0.0, true); + + // verify output from sequence file + Path path = new Path(output, "clusters-0-final/part-r-00000"); + int ix = 0; + for (ClusterWritable clusterWritable : new SequenceFileValueIterable<ClusterWritable>(path, true, + config)) { + assertEquals("Center [" + ix + ']', manhattanCentroids.get(ix), clusterWritable.getValue() + .getCenter()); + ix++; + } + + path = new Path(output, "clusteredPoints/part-m-0"); + long count = HadoopUtil.countRecords(path, config); + assertEquals("number of points", points.size(), count); + } + + /** Story: User can cluster points using sequential execution */ + @Test + public void testClusteringEuclideanSeq() throws Exception { + List<VectorWritable> points = getPointsWritable(); + Configuration config = getConfiguration(); + ClusteringTestUtils.writePointsToFile(points, + getTestTempFilePath("testdata/file1"), fs, config); + // now run the Canopy Driver in sequential mode + Path output = getTestTempDirPath("output"); + String[] args = { optKey(DefaultOptionCreator.INPUT_OPTION), + getTestTempDirPath("testdata").toString(), + optKey(DefaultOptionCreator.OUTPUT_OPTION), output.toString(), + 
optKey(DefaultOptionCreator.DISTANCE_MEASURE_OPTION), + EuclideanDistanceMeasure.class.getName(), + optKey(DefaultOptionCreator.T1_OPTION), "3.1", + optKey(DefaultOptionCreator.T2_OPTION), "2.1", + optKey(DefaultOptionCreator.CLUSTERING_OPTION), + optKey(DefaultOptionCreator.OVERWRITE_OPTION), + optKey(DefaultOptionCreator.METHOD_OPTION), + DefaultOptionCreator.SEQUENTIAL_METHOD }; + ToolRunner.run(config, new CanopyDriver(), args); + + // verify output from sequence file + Path path = new Path(output, "clusters-0-final/part-r-00000"); + + int ix = 0; + for (ClusterWritable clusterWritable : new SequenceFileValueIterable<ClusterWritable>(path, true, + config)) { + assertEquals("Center [" + ix + ']', euclideanCentroids.get(ix), clusterWritable.getValue() + .getCenter()); + ix++; + } + + path = new Path(output, "clusteredPoints/part-m-0"); + long count = HadoopUtil.countRecords(path, config); + assertEquals("number of points", points.size(), count); + } + + /** Story: User can remove outliers while clustering points using sequential execution */ + @Test + public void testClusteringEuclideanWithOutlierRemovalSeq() throws Exception { + List<VectorWritable> points = getPointsWritable(); + Configuration config = getConfiguration(); + ClusteringTestUtils.writePointsToFile(points, + getTestTempFilePath("testdata/file1"), fs, config); + // now run the Canopy Driver in sequential mode + Path output = getTestTempDirPath("output"); + String[] args = { optKey(DefaultOptionCreator.INPUT_OPTION), + getTestTempDirPath("testdata").toString(), + optKey(DefaultOptionCreator.OUTPUT_OPTION), output.toString(), + optKey(DefaultOptionCreator.DISTANCE_MEASURE_OPTION), + EuclideanDistanceMeasure.class.getName(), + optKey(DefaultOptionCreator.T1_OPTION), "3.1", + optKey(DefaultOptionCreator.T2_OPTION), "2.1", + optKey(DefaultOptionCreator.OUTLIER_THRESHOLD), "0.5", + optKey(DefaultOptionCreator.CLUSTERING_OPTION), + optKey(DefaultOptionCreator.OVERWRITE_OPTION), + optKey(DefaultOptionCreator.METHOD_OPTION), + DefaultOptionCreator.SEQUENTIAL_METHOD }; + ToolRunner.run(config, new CanopyDriver(), args); + + // verify output from sequence file + Path path = new Path(output, "clusters-0-final/part-r-00000"); + + int ix = 0; + for (ClusterWritable clusterWritable : new SequenceFileValueIterable<ClusterWritable>(path, true, + config)) { + assertEquals("Center [" + ix + ']', euclideanCentroids.get(ix), clusterWritable.getValue() + .getCenter()); + ix++; + } + + path = new Path(output, "clusteredPoints/part-m-0"); + long count = HadoopUtil.countRecords(path, config); + int expectedPointsHavingPDFGreaterThanThreshold = 6; + assertEquals("number of points", expectedPointsHavingPDFGreaterThanThreshold, count); + } + + + /** + * Story: User can produce final point clustering using a Hadoop map/reduce + * job and a ManhattanDistanceMeasure. 
+ */ + @Test + public void testClusteringManhattanMR() throws Exception { + List<VectorWritable> points = getPointsWritable(); + Configuration conf = getConfiguration(); + ClusteringTestUtils.writePointsToFile(points, true, + getTestTempFilePath("testdata/file1"), fs, conf); + ClusteringTestUtils.writePointsToFile(points, true, + getTestTempFilePath("testdata/file2"), fs, conf); + // now run the Job + Path output = getTestTempDirPath("output"); + CanopyDriver.run(conf, getTestTempDirPath("testdata"), output, + manhattanDistanceMeasure, 3.1, 2.1, true, 0.0, false); + Path path = new Path(output, "clusteredPoints/part-m-00000"); + long count = HadoopUtil.countRecords(path, conf); + assertEquals("number of points", points.size(), count); + } + + /** + * Story: User can produce final point clustering using a Hadoop map/reduce + * job and a EuclideanDistanceMeasure. + */ + @Test + public void testClusteringEuclideanMR() throws Exception { + List<VectorWritable> points = getPointsWritable(); + Configuration conf = getConfiguration(); + ClusteringTestUtils.writePointsToFile(points, true, + getTestTempFilePath("testdata/file1"), fs, conf); + ClusteringTestUtils.writePointsToFile(points, true, + getTestTempFilePath("testdata/file2"), fs, conf); + // now run the Job using the run() command. Others can use runJob(). + Path output = getTestTempDirPath("output"); + String[] args = { optKey(DefaultOptionCreator.INPUT_OPTION), + getTestTempDirPath("testdata").toString(), + optKey(DefaultOptionCreator.OUTPUT_OPTION), output.toString(), + optKey(DefaultOptionCreator.DISTANCE_MEASURE_OPTION), + EuclideanDistanceMeasure.class.getName(), + optKey(DefaultOptionCreator.T1_OPTION), "3.1", + optKey(DefaultOptionCreator.T2_OPTION), "2.1", + optKey(DefaultOptionCreator.CLUSTERING_OPTION), + optKey(DefaultOptionCreator.OVERWRITE_OPTION) }; + ToolRunner.run(getConfiguration(), new CanopyDriver(), args); + Path path = new Path(output, "clusteredPoints/part-m-00000"); + long count = HadoopUtil.countRecords(path, conf); + assertEquals("number of points", points.size(), count); + } + + /** + * Story: User can produce final point clustering using a Hadoop map/reduce + * job and a EuclideanDistanceMeasure and outlier removal threshold. + */ + @Test + public void testClusteringEuclideanWithOutlierRemovalMR() throws Exception { + List<VectorWritable> points = getPointsWritable(); + Configuration conf = getConfiguration(); + ClusteringTestUtils.writePointsToFile(points, true, + getTestTempFilePath("testdata/file1"), fs, conf); + ClusteringTestUtils.writePointsToFile(points, true, + getTestTempFilePath("testdata/file2"), fs, conf); + // now run the Job using the run() command. Others can use runJob(). 
+ Path output = getTestTempDirPath("output"); + String[] args = { optKey(DefaultOptionCreator.INPUT_OPTION), + getTestTempDirPath("testdata").toString(), + optKey(DefaultOptionCreator.OUTPUT_OPTION), output.toString(), + optKey(DefaultOptionCreator.DISTANCE_MEASURE_OPTION), + EuclideanDistanceMeasure.class.getName(), + optKey(DefaultOptionCreator.T1_OPTION), "3.1", + optKey(DefaultOptionCreator.T2_OPTION), "2.1", + optKey(DefaultOptionCreator.OUTLIER_THRESHOLD), "0.7", + optKey(DefaultOptionCreator.CLUSTERING_OPTION), + optKey(DefaultOptionCreator.OVERWRITE_OPTION) }; + ToolRunner.run(getConfiguration(), new CanopyDriver(), args); + Path path = new Path(output, "clusteredPoints/part-m-00000"); + long count = HadoopUtil.countRecords(path, conf); + int expectedPointsAfterOutlierRemoval = 8; + assertEquals("number of points", expectedPointsAfterOutlierRemoval, count); + } + + + /** + * Story: User can set T3 and T4 values to be used by the reducer for its T1 + * and T2 thresholds + */ + @Test + public void testCanopyReducerT3T4Configuration() throws Exception { + CanopyReducer reducer = new CanopyReducer(); + Configuration conf = getConfiguration(); + conf.set(CanopyConfigKeys.DISTANCE_MEASURE_KEY, + "org.apache.mahout.common.distance.ManhattanDistanceMeasure"); + conf.set(CanopyConfigKeys.T1_KEY, String.valueOf(3.1)); + conf.set(CanopyConfigKeys.T2_KEY, String.valueOf(2.1)); + conf.set(CanopyConfigKeys.T3_KEY, String.valueOf(1.1)); + conf.set(CanopyConfigKeys.T4_KEY, String.valueOf(0.1)); + conf.set(CanopyConfigKeys.CF_KEY, "0"); + DummyRecordWriter<Text, ClusterWritable> writer = new DummyRecordWriter<>(); + Reducer<Text, VectorWritable, Text, ClusterWritable>.Context context = DummyRecordWriter + .build(reducer, conf, writer, Text.class, VectorWritable.class); + reducer.setup(context); + assertEquals(1.1, reducer.getCanopyClusterer().getT1(), EPSILON); + assertEquals(0.1, reducer.getCanopyClusterer().getT2(), EPSILON); + } + + /** + * Story: User can specify a clustering limit that prevents output of small + * clusters + */ + @Test + public void testCanopyMapperClusterFilter() throws Exception { + CanopyMapper mapper = new CanopyMapper(); + Configuration conf = getConfiguration(); + conf.set(CanopyConfigKeys.DISTANCE_MEASURE_KEY, manhattanDistanceMeasure + .getClass().getName()); + conf.set(CanopyConfigKeys.T1_KEY, String.valueOf(3.1)); + conf.set(CanopyConfigKeys.T2_KEY, String.valueOf(2.1)); + conf.set(CanopyConfigKeys.CF_KEY, "3"); + DummyRecordWriter<Text, VectorWritable> writer = new DummyRecordWriter<>(); + Mapper<WritableComparable<?>, VectorWritable, Text, VectorWritable>.Context context = DummyRecordWriter + .build(mapper, conf, writer); + mapper.setup(context); + + List<VectorWritable> points = getPointsWritable(); + // map the data + for (VectorWritable point : points) { + mapper.map(new Text(), point, context); + } + mapper.cleanup(context); + assertEquals("Number of map results", 1, writer.getData().size()); + // now verify the output + List<VectorWritable> data = writer.getValue(new Text("centroid")); + assertEquals("Number of centroids", 2, data.size()); + } + + /** + * Story: User can specify a cluster filter that limits the minimum size of + * canopies produced by the reducer + */ + @Test + public void testCanopyReducerClusterFilter() throws Exception { + CanopyReducer reducer = new CanopyReducer(); + Configuration conf = getConfiguration(); + conf.set(CanopyConfigKeys.DISTANCE_MEASURE_KEY, + "org.apache.mahout.common.distance.ManhattanDistanceMeasure"); + 
conf.set(CanopyConfigKeys.T1_KEY, String.valueOf(3.1)); + conf.set(CanopyConfigKeys.T2_KEY, String.valueOf(2.1)); + conf.set(CanopyConfigKeys.CF_KEY, "3"); + DummyRecordWriter<Text, ClusterWritable> writer = new DummyRecordWriter<>(); + Reducer<Text, VectorWritable, Text, ClusterWritable>.Context context = DummyRecordWriter + .build(reducer, conf, writer, Text.class, VectorWritable.class); + reducer.setup(context); + + List<VectorWritable> points = getPointsWritable(); + reducer.reduce(new Text("centroid"), points, context); + Set<Text> keys = writer.getKeys(); + assertEquals("Number of centroids", 2, keys.size()); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/clustering/classify/ClusterClassificationDriverTest.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/clustering/classify/ClusterClassificationDriverTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/clustering/classify/ClusterClassificationDriverTest.java new file mode 100644 index 0000000..cbf0e55 --- /dev/null +++ b/community/mahout-mr/src/test/java/org/apache/mahout/clustering/classify/ClusterClassificationDriverTest.java @@ -0,0 +1,255 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.clustering.classify; + +import java.io.IOException; +import java.util.List; +import java.util.Set; + +import org.apache.commons.lang3.ArrayUtils; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FileStatus; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.FileUtil; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.SequenceFile; +import org.apache.hadoop.io.Writable; +import org.apache.mahout.clustering.ClusteringTestUtils; +import org.apache.mahout.clustering.canopy.CanopyDriver; +import org.apache.mahout.clustering.iterator.CanopyClusteringPolicy; +import org.apache.mahout.common.HadoopUtil; +import org.apache.mahout.common.MahoutTestCase; +import org.apache.mahout.common.distance.ManhattanDistanceMeasure; +import org.apache.mahout.common.iterator.sequencefile.PathFilters; +import org.apache.mahout.math.NamedVector; +import org.apache.mahout.math.RandomAccessSparseVector; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.VectorWritable; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import com.google.common.collect.Lists; +import com.google.common.collect.Sets; + +public class ClusterClassificationDriverTest extends MahoutTestCase { + + private static final double[][] REFERENCE = { {1, 1}, {2, 1}, {1, 2}, {4, 4}, + {5, 4}, {4, 5}, {5, 5}, {9, 9}, {8, 8}}; + + private FileSystem fs; + private Path clusteringOutputPath; + private Configuration conf; + private Path pointsPath; + private Path classifiedOutputPath; + private List<Vector> firstCluster; + private List<Vector> secondCluster; + private List<Vector> thirdCluster; + + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + Configuration conf = getConfiguration(); + fs = FileSystem.get(conf); + firstCluster = Lists.newArrayList(); + secondCluster = Lists.newArrayList(); + thirdCluster = Lists.newArrayList(); + + } + + private static List<VectorWritable> getPointsWritable(double[][] raw) { + List<VectorWritable> points = Lists.newArrayList(); + for (double[] fr : raw) { + Vector vec = new RandomAccessSparseVector(fr.length); + vec.assign(fr); + points.add(new VectorWritable(vec)); + } + return points; + } + + @Test + public void testVectorClassificationWithOutlierRemovalMR() throws Exception { + List<VectorWritable> points = getPointsWritable(REFERENCE); + + pointsPath = getTestTempDirPath("points"); + clusteringOutputPath = getTestTempDirPath("output"); + classifiedOutputPath = getTestTempDirPath("classifiedClusters"); + HadoopUtil.delete(conf, classifiedOutputPath); + + conf = getConfiguration(); + + ClusteringTestUtils.writePointsToFile(points, true, + new Path(pointsPath, "file1"), fs, conf); + runClustering(pointsPath, conf, false); + runClassificationWithOutlierRemoval(false); + collectVectorsForAssertion(); + assertVectorsWithOutlierRemoval(); + } + + @Test + public void testVectorClassificationWithoutOutlierRemoval() throws Exception { + List<VectorWritable> points = getPointsWritable(REFERENCE); + + pointsPath = getTestTempDirPath("points"); + clusteringOutputPath = getTestTempDirPath("output"); + classifiedOutputPath = getTestTempDirPath("classify"); + + conf = getConfiguration(); + + ClusteringTestUtils.writePointsToFile(points, + new Path(pointsPath, "file1"), fs, conf); + runClustering(pointsPath, conf, true); + runClassificationWithoutOutlierRemoval(); + collectVectorsForAssertion(); + 
assertVectorsWithoutOutlierRemoval(); + } + + @Test + public void testVectorClassificationWithOutlierRemoval() throws Exception { + List<VectorWritable> points = getPointsWritable(REFERENCE); + + pointsPath = getTestTempDirPath("points"); + clusteringOutputPath = getTestTempDirPath("output"); + classifiedOutputPath = getTestTempDirPath("classify"); + + conf = getConfiguration(); + + ClusteringTestUtils.writePointsToFile(points, + new Path(pointsPath, "file1"), fs, conf); + runClustering(pointsPath, conf, true); + runClassificationWithOutlierRemoval(true); + collectVectorsForAssertion(); + assertVectorsWithOutlierRemoval(); + } + + private void runClustering(Path pointsPath, Configuration conf, + Boolean runSequential) throws IOException, InterruptedException, + ClassNotFoundException { + CanopyDriver.run(conf, pointsPath, clusteringOutputPath, + new ManhattanDistanceMeasure(), 3.1, 2.1, false, 0.0, runSequential); + Path finalClustersPath = new Path(clusteringOutputPath, "clusters-0-final"); + ClusterClassifier.writePolicy(new CanopyClusteringPolicy(), + finalClustersPath); + } + + private void runClassificationWithoutOutlierRemoval() + throws IOException, InterruptedException, ClassNotFoundException { + ClusterClassificationDriver.run(getConfiguration(), pointsPath, clusteringOutputPath, classifiedOutputPath, 0.0, true, true); + } + + private void runClassificationWithOutlierRemoval(boolean runSequential) + throws IOException, InterruptedException, ClassNotFoundException { + ClusterClassificationDriver.run(getConfiguration(), pointsPath, clusteringOutputPath, classifiedOutputPath, 0.73, true, runSequential); + } + + private void collectVectorsForAssertion() throws IOException { + Path[] partFilePaths = FileUtil.stat2Paths(fs + .globStatus(classifiedOutputPath)); + FileStatus[] listStatus = fs.listStatus(partFilePaths, + PathFilters.partFilter()); + for (FileStatus partFile : listStatus) { + SequenceFile.Reader classifiedVectors = new SequenceFile.Reader(fs, + partFile.getPath(), conf); + Writable clusterIdAsKey = new IntWritable(); + WeightedPropertyVectorWritable point = new WeightedPropertyVectorWritable(); + while (classifiedVectors.next(clusterIdAsKey, point)) { + collectVector(clusterIdAsKey.toString(), point.getVector()); + } + } + } + + private void collectVector(String clusterId, Vector vector) { + if ("0".equals(clusterId)) { + firstCluster.add(vector); + } else if ("1".equals(clusterId)) { + secondCluster.add(vector); + } else if ("2".equals(clusterId)) { + thirdCluster.add(vector); + } + } + + private void assertVectorsWithOutlierRemoval() { + checkClustersWithOutlierRemoval(); + } + + private void assertVectorsWithoutOutlierRemoval() { + assertFirstClusterWithoutOutlierRemoval(); + assertSecondClusterWithoutOutlierRemoval(); + assertThirdClusterWithoutOutlierRemoval(); + } + + private void assertThirdClusterWithoutOutlierRemoval() { + Assert.assertEquals(2, thirdCluster.size()); + for (Vector vector : thirdCluster) { + Assert.assertTrue(ArrayUtils.contains(new String[] {"{0:9.0,1:9.0}", + "{0:8.0,1:8.0}"}, vector.asFormatString())); + } + } + + private void assertSecondClusterWithoutOutlierRemoval() { + Assert.assertEquals(4, secondCluster.size()); + for (Vector vector : secondCluster) { + Assert.assertTrue(ArrayUtils.contains(new String[] {"{0:4.0,1:4.0}", + "{0:5.0,1:4.0}", "{0:4.0,1:5.0}", "{0:5.0,1:5.0}"}, + vector.asFormatString())); + } + } + + private void assertFirstClusterWithoutOutlierRemoval() { + Assert.assertEquals(3, firstCluster.size()); + for (Vector vector : 
firstCluster) { + Assert.assertTrue(ArrayUtils.contains(new String[] {"{0:1.0,1:1.0}", + "{0:2.0,1:1.0}", "{0:1.0,1:2.0}"}, vector.asFormatString())); + } + } + + private void checkClustersWithOutlierRemoval() { + Set<String> reference = Sets.newHashSet("{0:9.0,1:9.0}", "{0:1.0,1:1.0}"); + + List<List<Vector>> clusters = Lists.newArrayList(); + clusters.add(firstCluster); + clusters.add(secondCluster); + clusters.add(thirdCluster); + + int singletonCnt = 0; + int emptyCnt = 0; + for (List<Vector> vList : clusters) { + if (vList.isEmpty()) { + emptyCnt++; + } else { + singletonCnt++; + assertEquals("expecting only singleton clusters; got size=" + vList.size(), 1, vList.size()); + if (vList.get(0).getClass().equals(NamedVector.class)) { + Assert.assertTrue("not expecting cluster:" + ((NamedVector) vList.get(0)).getDelegate().asFormatString(), + reference.contains(((NamedVector) vList.get(0)).getDelegate().asFormatString())); + reference.remove(((NamedVector)vList.get(0)).getDelegate().asFormatString()); + } else if (vList.get(0).getClass().equals(RandomAccessSparseVector.class)) { + Assert.assertTrue("not expecting cluster:" + vList.get(0).asFormatString(), + reference.contains(vList.get(0).asFormatString())); + reference.remove(vList.get(0).asFormatString()); + } + } + } + Assert.assertEquals("Different number of empty clusters than expected!", 1, emptyCnt); + Assert.assertEquals("Different number of singletons than expected!", 2, singletonCnt); + Assert.assertEquals("Didn't match all reference clusters!", 0, reference.size()); + } + +}
