http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/clustering/streaming/cluster/BallKMeansTest.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/clustering/streaming/cluster/BallKMeansTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/clustering/streaming/cluster/BallKMeansTest.java new file mode 100644 index 0000000..340ca8e --- /dev/null +++ b/community/mahout-mr/src/test/java/org/apache/mahout/clustering/streaming/cluster/BallKMeansTest.java @@ -0,0 +1,196 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.clustering.streaming.cluster; + +import java.util.List; + +import com.google.common.collect.Lists; +import org.apache.mahout.clustering.ClusteringUtils; +import org.apache.mahout.common.Pair; +import org.apache.mahout.common.RandomUtils; +import org.apache.mahout.common.distance.EuclideanDistanceMeasure; +import org.apache.mahout.common.distance.SquaredEuclideanDistanceMeasure; +import org.apache.mahout.math.Centroid; +import org.apache.mahout.math.ConstantVector; +import org.apache.mahout.math.DenseMatrix; +import org.apache.mahout.math.DenseVector; +import org.apache.mahout.math.Matrix; +import org.apache.mahout.math.SingularValueDecomposition; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.WeightedVector; +import org.apache.mahout.math.function.Functions; +import org.apache.mahout.math.function.VectorFunction; +import org.apache.mahout.math.neighborhood.BruteSearch; +import org.apache.mahout.math.neighborhood.Searcher; +import org.apache.mahout.math.neighborhood.UpdatableSearcher; +import org.apache.mahout.math.random.MultiNormal; +import org.apache.mahout.math.random.WeightedThing; +import org.apache.mahout.math.stats.OnlineSummarizer; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +import static org.apache.mahout.clustering.ClusteringUtils.totalWeight; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +public class BallKMeansTest { + private static final int NUM_DATA_POINTS = 10000; + private static final int NUM_DIMENSIONS = 4; + private static final int NUM_ITERATIONS = 20; + private static final double DISTRIBUTION_RADIUS = 0.01; + + @BeforeClass + public static void setUp() { + RandomUtils.useTestSeed(); + syntheticData = DataUtils.sampleMultiNormalHypercube(NUM_DIMENSIONS, NUM_DATA_POINTS, DISTRIBUTION_RADIUS); + + } + + private static Pair<List<Centroid>, List<Centroid>> syntheticData; + private static final int K1 = 100; + + + @Test + public void testClusteringMultipleRuns() { + for (int i = 1; i <= 10; ++i) { + BallKMeans clusterer = 
new BallKMeans(new BruteSearch(new SquaredEuclideanDistanceMeasure()), + 1 << NUM_DIMENSIONS, NUM_ITERATIONS, true, i); + clusterer.cluster(syntheticData.getFirst()); + double costKMeansPlusPlus = ClusteringUtils.totalClusterCost(syntheticData.getFirst(), clusterer); + + clusterer = new BallKMeans(new BruteSearch(new SquaredEuclideanDistanceMeasure()), + 1 << NUM_DIMENSIONS, NUM_ITERATIONS, false, i); + clusterer.cluster(syntheticData.getFirst()); + double costKMeansRandom = ClusteringUtils.totalClusterCost(syntheticData.getFirst(), clusterer); + + System.out.printf("%d runs; kmeans++: %f; random: %f\n", i, costKMeansPlusPlus, costKMeansRandom); + assertTrue("kmeans++ cost should be less than random cost", costKMeansPlusPlus < costKMeansRandom); + } + } + + @Test + public void testClustering() { + UpdatableSearcher searcher = new BruteSearch(new SquaredEuclideanDistanceMeasure()); + BallKMeans clusterer = new BallKMeans(searcher, 1 << NUM_DIMENSIONS, NUM_ITERATIONS); + + long startTime = System.currentTimeMillis(); + Pair<List<Centroid>, List<Centroid>> data = syntheticData; + clusterer.cluster(data.getFirst()); + long endTime = System.currentTimeMillis(); + + long hash = 0; + for (Centroid centroid : data.getFirst()) { + for (Vector.Element element : centroid.all()) { + hash = 31 * hash + 17 * element.index() + Double.toHexString(element.get()).hashCode(); + } + } + System.out.printf("Hash = %08x\n", hash); + + assertEquals("Total weight not preserved", totalWeight(syntheticData.getFirst()), totalWeight(clusterer), 1.0e-9); + + // Verify that each corner of the cube has a centroid very nearby. + // This is probably FALSE for large-dimensional spaces! + OnlineSummarizer summarizer = new OnlineSummarizer(); + for (Vector mean : syntheticData.getSecond()) { + WeightedThing<Vector> v = searcher.search(mean, 1).get(0); + summarizer.add(v.getWeight()); + } + assertTrue(String.format("Median weight [%f] too large [>%f]", summarizer.getMedian(), + DISTRIBUTION_RADIUS), summarizer.getMedian() < DISTRIBUTION_RADIUS); + + double clusterTime = (endTime - startTime) / 1000.0; + System.out.printf("%s\n%.2f for clustering\n%.1f us per row\n\n", + searcher.getClass().getName(), clusterTime, + clusterTime / syntheticData.getFirst().size() * 1.0e6); + + // Verify that the total weight of the centroids near each corner is correct. + double[] cornerWeights = new double[1 << NUM_DIMENSIONS]; + Searcher trueFinder = new BruteSearch(new EuclideanDistanceMeasure()); + for (Vector trueCluster : syntheticData.getSecond()) { + trueFinder.add(trueCluster); + } + for (Centroid centroid : clusterer) { + WeightedThing<Vector> closest = trueFinder.search(centroid, 1).get(0); + cornerWeights[((Centroid)closest.getValue()).getIndex()] += centroid.getWeight(); + } + int expectedNumPoints = NUM_DATA_POINTS / (1 << NUM_DIMENSIONS); + for (double v : cornerWeights) { + System.out.printf("%f ", v); + } + System.out.println(); + for (double v : cornerWeights) { + assertEquals(expectedNumPoints, v, 0); + } + } + + @Test + public void testInitialization() { + // Start with super clusterable data. + List<? extends WeightedVector> data = cubishTestData(0.01); + + // Just do initialization of ball k-means. This should drop a point into each of the clusters. + BallKMeans r = new BallKMeans(new BruteSearch(new SquaredEuclideanDistanceMeasure()), 6, 20); + r.cluster(data); + + // Put the centroids into a matrix. 
+ Matrix x = new DenseMatrix(6, 5); + int row = 0; + for (Centroid c : r) { + x.viewRow(row).assign(c.viewPart(0, 5)); + row++; + } + + // Verify that each column looks right. Should contain zeros except for a single 6. + final Vector columnNorms = x.aggregateColumns(new VectorFunction() { + @Override + public double apply(Vector f) { + // Return the sum of three discrepancy measures. + return Math.abs(f.minValue()) + Math.abs(f.maxValue() - 6) + Math.abs(f.norm(1) - 6); + } + }); + // Verify all errors are nearly zero. + assertEquals(0, columnNorms.norm(1) / columnNorms.size(), 0.1); + + // Verify that the centroids are a permutation of the original ones. + SingularValueDecomposition svd = new SingularValueDecomposition(x); + Vector s = svd.getS().viewDiagonal().assign(Functions.div(6)); + assertEquals(5, s.getLengthSquared(), 0.05); + assertEquals(5, s.norm(1), 0.05); + } + + private static List<? extends WeightedVector> cubishTestData(double radius) { + List<WeightedVector> data = Lists.newArrayListWithCapacity(K1 + 5000); + int row = 0; + + MultiNormal g = new MultiNormal(radius, new ConstantVector(0, 10)); + for (int i = 0; i < K1; i++) { + data.add(new WeightedVector(g.sample(), 1, row++)); + } + + for (int i = 0; i < 5; i++) { + Vector m = new DenseVector(10); + m.set(i, 6); // This was originally i == 0 ? 6 : 6 which can't be right + MultiNormal gx = new MultiNormal(radius, m); + for (int j = 0; j < 1000; j++) { + data.add(new WeightedVector(gx.sample(), 1, row++)); + } + } + return data; + } +}
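For context, the BallKMeans API exercised by the test above can be driven outside JUnit roughly as follows. This is a minimal sketch, not part of the commit: it reuses only the calls visible in BallKMeansTest and DataUtils (the constructor, cluster(), iteration over the resulting centroids, and ClusteringUtils.totalClusterCost), and the example class name and hard-coded constants are illustrative.

package org.apache.mahout.clustering.streaming.cluster;

import java.util.List;

import org.apache.mahout.clustering.ClusteringUtils;
import org.apache.mahout.common.distance.SquaredEuclideanDistanceMeasure;
import org.apache.mahout.math.Centroid;
import org.apache.mahout.math.neighborhood.BruteSearch;

public class BallKMeansExample {
  public static void main(String[] args) {
    // Synthetic points around the corners of a 4-dimensional unit hypercube,
    // generated the same way as in the test's setUp().
    List<Centroid> points = DataUtils.sampleMultiNormalHypercube(4, 10000, 0.01).getFirst();

    // One cluster per hypercube corner (2^4 = 16), up to 20 iterations,
    // mirroring the searcher and constants used in testClustering().
    BallKMeans clusterer =
        new BallKMeans(new BruteSearch(new SquaredEuclideanDistanceMeasure()), 1 << 4, 20);
    clusterer.cluster(points);

    // The clusterer is iterable over the centroids it produced.
    for (Centroid centroid : clusterer) {
      System.out.printf("cluster %d: weight %.1f\n", centroid.getIndex(), centroid.getWeight());
    }

    // Total cost of assigning every point to its closest centroid, the quantity
    // compared across runs in testClusteringMultipleRuns().
    System.out.printf("total cost: %f\n", ClusteringUtils.totalClusterCost(points, clusterer));
  }
}

As testClusteringMultipleRuns above shows, there is also a five-argument constructor whose extra parameters the test treats as a kmeans++ seeding flag and a number of runs, and the test asserts that kmeans++ seeding yields a lower total cluster cost than random seeding.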
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/clustering/streaming/cluster/DataUtils.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/clustering/streaming/cluster/DataUtils.java b/community/mahout-mr/src/test/java/org/apache/mahout/clustering/streaming/cluster/DataUtils.java new file mode 100644 index 0000000..5a10a55 --- /dev/null +++ b/community/mahout-mr/src/test/java/org/apache/mahout/clustering/streaming/cluster/DataUtils.java @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.clustering.streaming.cluster; + +import java.util.List; + +import com.google.common.collect.Lists; +import org.apache.mahout.common.Pair; +import org.apache.mahout.math.Centroid; +import org.apache.mahout.math.DenseVector; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.random.MultiNormal; + +/** + * A collection of miscellaneous utility functions for working with data to be clustered. + * Includes methods for generating synthetic data and estimating distance cutoff. + */ +public final class DataUtils { + private DataUtils() { + } + + /** + * Samples numDatapoints vectors of numDimensions cardinality centered around the vertices of a + * numDimensions order hypercube. The distribution of points around these vertices is + * multinormal with a radius of distributionRadius. + * A hypercube of numDimensions has 2^numDimensions vertices. Keep this in mind when clustering + * the data. + * + * Note that it is almost always the case that you want to call RandomUtils.useTestSeed() before + * generating test data. This means that you can't generate data in the declaration of a static + * variable because such initializations happen before any @BeforeClass or @Before setup methods + * are called. + * + * + * @param numDimensions number of dimensions of the vectors to be generated. + * @param numDatapoints number of data points to be generated. + * @param distributionRadius radius of the distribution around the hypercube vertices. + * @return a pair of lists, whose first element is the sampled points and whose second element + * is the list of hypercube vertices that are the means of each distribution. + */ + public static Pair<List<Centroid>, List<Centroid>> sampleMultiNormalHypercube( + int numDimensions, int numDatapoints, double distributionRadius) { + int pow2N = 1 << numDimensions; + // Construct data samplers centered on the corners of a unit hypercube. + // Additionally, keep the means of the distributions that will be generated so we can compare + // these to the ideal cluster centers. 
+ List<Centroid> mean = Lists.newArrayListWithCapacity(pow2N); + List<MultiNormal> rowSamplers = Lists.newArrayList(); + for (int i = 0; i < pow2N; i++) { + Vector v = new DenseVector(numDimensions); + // Select each of the num + int pow2J = 1 << (numDimensions - 1); + for (int j = 0; j < numDimensions; ++j) { + v.set(j, 1.0 / pow2J * (i & pow2J)); + pow2J >>= 1; + } + mean.add(new Centroid(i, v, 1)); + rowSamplers.add(new MultiNormal(distributionRadius, v)); + } + + // Sample the requested number of data points. + List<Centroid> data = Lists.newArrayListWithCapacity(numDatapoints); + for (int i = 0; i < numDatapoints; ++i) { + data.add(new Centroid(i, rowSamplers.get(i % pow2N).sample(), 1)); + } + return new Pair<>(data, mean); + } + + /** + * Calls sampleMultinormalHypercube(numDimension, numDataPoints, 0.01). + * @see DataUtils#sampleMultiNormalHypercube(int, int, double) + */ + public static Pair<List<Centroid>, List<Centroid>> sampleMultiNormalHypercube(int numDimensions, + int numDatapoints) { + return sampleMultiNormalHypercube(numDimensions, numDatapoints, 0.01); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/clustering/streaming/cluster/StreamingKMeansTest.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/clustering/streaming/cluster/StreamingKMeansTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/clustering/streaming/cluster/StreamingKMeansTest.java new file mode 100644 index 0000000..cf9263c --- /dev/null +++ b/community/mahout-mr/src/test/java/org/apache/mahout/clustering/streaming/cluster/StreamingKMeansTest.java @@ -0,0 +1,169 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.clustering.streaming.cluster; + + +import java.util.Arrays; +import java.util.List; + +import org.apache.mahout.clustering.ClusteringUtils; +import org.apache.mahout.common.Pair; +import org.apache.mahout.common.RandomUtils; +import org.apache.mahout.common.distance.EuclideanDistanceMeasure; +import org.apache.mahout.common.distance.SquaredEuclideanDistanceMeasure; +import org.apache.mahout.math.Centroid; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.neighborhood.BruteSearch; +import org.apache.mahout.math.neighborhood.FastProjectionSearch; +import org.apache.mahout.math.neighborhood.ProjectionSearch; +import org.apache.mahout.math.neighborhood.Searcher; +import org.apache.mahout.math.neighborhood.UpdatableSearcher; +import org.apache.mahout.math.random.WeightedThing; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import static org.junit.runners.Parameterized.Parameters; + + +@RunWith(Parameterized.class) +public class StreamingKMeansTest { + private static final int NUM_DATA_POINTS = 1 << 16; + private static final int NUM_DIMENSIONS = 6; + private static final int NUM_PROJECTIONS = 2; + private static final int SEARCH_SIZE = 10; + + private static Pair<List<Centroid>, List<Centroid>> syntheticData ; + + @Before + public void setUp() { + RandomUtils.useTestSeed(); + syntheticData = + DataUtils.sampleMultiNormalHypercube(NUM_DIMENSIONS, NUM_DATA_POINTS); + } + + private UpdatableSearcher searcher; + private boolean allAtOnce; + + public StreamingKMeansTest(UpdatableSearcher searcher, boolean allAtOnce) { + this.searcher = searcher; + this.allAtOnce = allAtOnce; + } + + @Parameters + public static List<Object[]> generateData() { + return Arrays.asList(new Object[][] { + {new ProjectionSearch(new SquaredEuclideanDistanceMeasure(), NUM_PROJECTIONS, SEARCH_SIZE), true}, + {new FastProjectionSearch(new SquaredEuclideanDistanceMeasure(), NUM_PROJECTIONS, SEARCH_SIZE), + true}, + {new ProjectionSearch(new SquaredEuclideanDistanceMeasure(), NUM_PROJECTIONS, SEARCH_SIZE), false}, + {new FastProjectionSearch(new SquaredEuclideanDistanceMeasure(), NUM_PROJECTIONS, SEARCH_SIZE), + false}, + }); + } + + @Test + public void testAverageDistanceCutoff() { + double avgDistanceCutoff = 0; + double avgNumClusters = 0; + int numTests = 1; + System.out.printf("Distance cutoff for %s\n", searcher.getClass().getName()); + for (int i = 0; i < numTests; ++i) { + searcher.clear(); + int numStreamingClusters = (int)Math.log(syntheticData.getFirst().size()) * (1 << + NUM_DIMENSIONS); + double distanceCutoff = 1.0e-6; + double estimatedCutoff = ClusteringUtils.estimateDistanceCutoff(syntheticData.getFirst(), + searcher.getDistanceMeasure(), 100); + System.out.printf("[%d] Generated synthetic data [magic] %f [estimate] %f\n", i, distanceCutoff, estimatedCutoff); + StreamingKMeans clusterer = + new StreamingKMeans(searcher, numStreamingClusters, estimatedCutoff); + clusterer.cluster(syntheticData.getFirst()); + avgDistanceCutoff += clusterer.getDistanceCutoff(); + avgNumClusters += clusterer.getNumClusters(); + System.out.printf("[%d] %f\n", i, clusterer.getDistanceCutoff()); + } + avgDistanceCutoff /= numTests; + avgNumClusters /= numTests; + System.out.printf("Final: distanceCutoff: %f estNumClusters: %f\n", avgDistanceCutoff, avgNumClusters); + } + + @Test + public void testClustering() { + 
searcher.clear(); + int numStreamingClusters = (int)Math.log(syntheticData.getFirst().size()) * (1 << NUM_DIMENSIONS); + System.out.printf("k log n = %d\n", numStreamingClusters); + double estimatedCutoff = ClusteringUtils.estimateDistanceCutoff(syntheticData.getFirst(), + searcher.getDistanceMeasure(), 100); + StreamingKMeans clusterer = + new StreamingKMeans(searcher, numStreamingClusters, estimatedCutoff); + + long startTime = System.currentTimeMillis(); + if (allAtOnce) { + clusterer.cluster(syntheticData.getFirst()); + } else { + for (Centroid datapoint : syntheticData.getFirst()) { + clusterer.cluster(datapoint); + } + } + long endTime = System.currentTimeMillis(); + + System.out.printf("%s %s\n", searcher.getClass().getName(), searcher.getDistanceMeasure() + .getClass().getName()); + System.out.printf("Total number of clusters %d\n", clusterer.getNumClusters()); + + System.out.printf("Weights: %f %f\n", ClusteringUtils.totalWeight(syntheticData.getFirst()), + ClusteringUtils.totalWeight(clusterer)); + assertEquals("Total weight not preserved", ClusteringUtils.totalWeight(syntheticData.getFirst()), + ClusteringUtils.totalWeight(clusterer), 1.0e-9); + + // and verify that each corner of the cube has a centroid very nearby + double maxWeight = 0; + for (Vector mean : syntheticData.getSecond()) { + WeightedThing<Vector> v = searcher.search(mean, 1).get(0); + maxWeight = Math.max(v.getWeight(), maxWeight); + } + assertTrue("Maximum weight too large " + maxWeight, maxWeight < 0.05); + double clusterTime = (endTime - startTime) / 1000.0; + System.out.printf("%s\n%.2f for clustering\n%.1f us per row\n\n", + searcher.getClass().getName(), clusterTime, + clusterTime / syntheticData.getFirst().size() * 1.0e6); + + // verify that the total weight of the centroids near each corner is correct + double[] cornerWeights = new double[1 << NUM_DIMENSIONS]; + Searcher trueFinder = new BruteSearch(new EuclideanDistanceMeasure()); + for (Vector trueCluster : syntheticData.getSecond()) { + trueFinder.add(trueCluster); + } + for (Centroid centroid : clusterer) { + WeightedThing<Vector> closest = trueFinder.search(centroid, 1).get(0); + cornerWeights[((Centroid)closest.getValue()).getIndex()] += centroid.getWeight(); + } + int expectedNumPoints = NUM_DATA_POINTS / (1 << NUM_DIMENSIONS); + for (double v : cornerWeights) { + System.out.printf("%f ", v); + } + System.out.println(); + for (double v : cornerWeights) { + assertEquals(expectedNumPoints, v, 0); + } + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansTestMR.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansTestMR.java b/community/mahout-mr/src/test/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansTestMR.java new file mode 100644 index 0000000..dbf05be --- /dev/null +++ b/community/mahout-mr/src/test/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansTestMR.java @@ -0,0 +1,282 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.clustering.streaming.mapreduce; + +import java.io.IOException; +import java.util.Arrays; +import java.util.List; + +import com.google.common.base.Function; +import com.google.common.collect.Iterables; +import com.google.common.collect.Lists; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.Writable; +import org.apache.hadoop.mrunit.mapreduce.MapDriver; +import org.apache.hadoop.mrunit.mapreduce.MapReduceDriver; +import org.apache.hadoop.mrunit.mapreduce.ReduceDriver; +import org.apache.mahout.clustering.ClusteringUtils; +import org.apache.mahout.clustering.streaming.cluster.DataUtils; +import org.apache.mahout.clustering.streaming.cluster.StreamingKMeans; +import org.apache.mahout.common.MahoutTestCase; +import org.apache.mahout.common.Pair; +import org.apache.mahout.common.RandomUtils; +import org.apache.mahout.common.commandline.DefaultOptionCreator; +import org.apache.mahout.common.distance.SquaredEuclideanDistanceMeasure; +import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable; +import org.apache.mahout.math.Centroid; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.VectorWritable; +import org.apache.mahout.math.neighborhood.BruteSearch; +import org.apache.mahout.math.neighborhood.FastProjectionSearch; +import org.apache.mahout.math.neighborhood.LocalitySensitiveHashSearch; +import org.apache.mahout.math.neighborhood.ProjectionSearch; +import org.apache.mahout.math.random.WeightedThing; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +@RunWith(Parameterized.class) +public class StreamingKMeansTestMR extends MahoutTestCase { + private static final int NUM_DATA_POINTS = 1 << 15; + private static final int NUM_DIMENSIONS = 8; + private static final int NUM_PROJECTIONS = 3; + private static final int SEARCH_SIZE = 5; + private static final int MAX_NUM_ITERATIONS = 10; + private static final double DISTANCE_CUTOFF = 1.0e-6; + + private static Pair<List<Centroid>, List<Centroid>> syntheticData; + + @Before + public void setUp() { + RandomUtils.useTestSeed(); + syntheticData = + DataUtils.sampleMultiNormalHypercube(NUM_DIMENSIONS, NUM_DATA_POINTS, 1.0e-4); + } + + private final String searcherClassName; + private final String distanceMeasureClassName; + + public StreamingKMeansTestMR(String searcherClassName, String distanceMeasureClassName) { + this.searcherClassName = searcherClassName; + this.distanceMeasureClassName = distanceMeasureClassName; + } + + private void configure(Configuration configuration) { + configuration.set(DefaultOptionCreator.DISTANCE_MEASURE_OPTION, distanceMeasureClassName); + configuration.setInt(StreamingKMeansDriver.SEARCH_SIZE_OPTION, SEARCH_SIZE); + configuration.setInt(StreamingKMeansDriver.NUM_PROJECTIONS_OPTION, NUM_PROJECTIONS); + configuration.set(StreamingKMeansDriver.SEARCHER_CLASS_OPTION, 
searcherClassName); + configuration.setInt(DefaultOptionCreator.NUM_CLUSTERS_OPTION, 1 << NUM_DIMENSIONS); + configuration.setInt(StreamingKMeansDriver.ESTIMATED_NUM_MAP_CLUSTERS, + (1 << NUM_DIMENSIONS) * (int)Math.log(NUM_DATA_POINTS)); + configuration.setFloat(StreamingKMeansDriver.ESTIMATED_DISTANCE_CUTOFF, (float) DISTANCE_CUTOFF); + configuration.setInt(StreamingKMeansDriver.MAX_NUM_ITERATIONS, MAX_NUM_ITERATIONS); + + // Collapse the Centroids in the reducer. + configuration.setBoolean(StreamingKMeansDriver.REDUCE_STREAMING_KMEANS, true); + } + + @Parameterized.Parameters + public static List<Object[]> generateData() { + return Arrays.asList(new Object[][]{ + {ProjectionSearch.class.getName(), SquaredEuclideanDistanceMeasure.class.getName()}, + {FastProjectionSearch.class.getName(), SquaredEuclideanDistanceMeasure.class.getName()}, + {LocalitySensitiveHashSearch.class.getName(), SquaredEuclideanDistanceMeasure.class.getName()}, + }); + } + + @Test + public void testHypercubeMapper() throws IOException { + MapDriver<Writable, VectorWritable, IntWritable, CentroidWritable> mapDriver = + MapDriver.newMapDriver(new StreamingKMeansMapper()); + configure(mapDriver.getConfiguration()); + System.out.printf("%s mapper test\n", + mapDriver.getConfiguration().get(StreamingKMeansDriver.SEARCHER_CLASS_OPTION)); + for (Centroid datapoint : syntheticData.getFirst()) { + mapDriver.addInput(new IntWritable(0), new VectorWritable(datapoint)); + } + List<org.apache.hadoop.mrunit.types.Pair<IntWritable,CentroidWritable>> results = mapDriver.run(); + BruteSearch resultSearcher = new BruteSearch(new SquaredEuclideanDistanceMeasure()); + for (org.apache.hadoop.mrunit.types.Pair<IntWritable, CentroidWritable> result : results) { + resultSearcher.add(result.getSecond().getCentroid()); + } + System.out.printf("Clustered the data into %d clusters\n", results.size()); + for (Vector mean : syntheticData.getSecond()) { + WeightedThing<Vector> closest = resultSearcher.search(mean, 1).get(0); + assertTrue("Weight " + closest.getWeight() + " not less than 0.5", closest.getWeight() < 0.5); + } + } + + @Test + public void testMapperVsLocal() throws IOException { + // Clusters the data using the StreamingKMeansMapper. + MapDriver<Writable, VectorWritable, IntWritable, CentroidWritable> mapDriver = + MapDriver.newMapDriver(new StreamingKMeansMapper()); + Configuration configuration = mapDriver.getConfiguration(); + configure(configuration); + System.out.printf("%s mapper vs local test\n", + mapDriver.getConfiguration().get(StreamingKMeansDriver.SEARCHER_CLASS_OPTION)); + + for (Centroid datapoint : syntheticData.getFirst()) { + mapDriver.addInput(new IntWritable(0), new VectorWritable(datapoint)); + } + List<Centroid> mapperCentroids = Lists.newArrayList(); + for (org.apache.hadoop.mrunit.types.Pair<IntWritable, CentroidWritable> pair : mapDriver.run()) { + mapperCentroids.add(pair.getSecond().getCentroid()); + } + + // Clusters the data using local batch StreamingKMeans. + StreamingKMeans batchClusterer = + new StreamingKMeans(StreamingKMeansUtilsMR.searcherFromConfiguration(configuration), + mapDriver.getConfiguration().getInt("estimatedNumMapClusters", -1), DISTANCE_CUTOFF); + batchClusterer.cluster(syntheticData.getFirst()); + List<Centroid> batchCentroids = Lists.newArrayList(); + for (Vector v : batchClusterer) { + batchCentroids.add((Centroid) v); + } + + // Clusters the data using point by point StreamingKMeans. 
+ StreamingKMeans perPointClusterer = + new StreamingKMeans(StreamingKMeansUtilsMR.searcherFromConfiguration(configuration), + (1 << NUM_DIMENSIONS) * (int)Math.log(NUM_DATA_POINTS), DISTANCE_CUTOFF); + for (Centroid datapoint : syntheticData.getFirst()) { + perPointClusterer.cluster(datapoint); + } + List<Centroid> perPointCentroids = Lists.newArrayList(); + for (Vector v : perPointClusterer) { + perPointCentroids.add((Centroid) v); + } + + // Computes the cost (total sum of distances) of these different clusterings. + double mapperCost = ClusteringUtils.totalClusterCost(syntheticData.getFirst(), mapperCentroids); + double localCost = ClusteringUtils.totalClusterCost(syntheticData.getFirst(), batchCentroids); + double perPointCost = ClusteringUtils.totalClusterCost(syntheticData.getFirst(), perPointCentroids); + System.out.printf("[Total cost] Mapper %f [%d] Local %f [%d] Perpoint local %f [%d];" + + "[ratio m-vs-l %f] [ratio pp-vs-l %f]\n", mapperCost, mapperCentroids.size(), + localCost, batchCentroids.size(), perPointCost, perPointCentroids.size(), + mapperCost / localCost, perPointCost / localCost); + + // These ratios should be close to 1.0 and have been observed to be go as low as 0.6 and as low as 1.5. + // A buffer of [0.2, 1.8] seems appropriate. + assertEquals("Mapper StreamingKMeans / Batch local StreamingKMeans total cost ratio too far from 1", + 1.0, mapperCost / localCost, 0.8); + assertEquals("One by one local StreamingKMeans / Batch local StreamingKMeans total cost ratio too high", + 1.0, perPointCost / localCost, 0.8); + } + + @Test + public void testHypercubeReducer() throws IOException { + ReduceDriver<IntWritable, CentroidWritable, IntWritable, CentroidWritable> reduceDriver = + ReduceDriver.newReduceDriver(new StreamingKMeansReducer()); + Configuration configuration = reduceDriver.getConfiguration(); + configure(configuration); + + System.out.printf("%s reducer test\n", configuration.get(StreamingKMeansDriver.SEARCHER_CLASS_OPTION)); + StreamingKMeans clusterer = + new StreamingKMeans(StreamingKMeansUtilsMR .searcherFromConfiguration(configuration), + (1 << NUM_DIMENSIONS) * (int)Math.log(NUM_DATA_POINTS), DISTANCE_CUTOFF); + + long start = System.currentTimeMillis(); + clusterer.cluster(syntheticData.getFirst()); + long end = System.currentTimeMillis(); + + System.out.printf("%f [s]\n", (end - start) / 1000.0); + List<CentroidWritable> reducerInputs = Lists.newArrayList(); + int postMapperTotalWeight = 0; + for (Centroid intermediateCentroid : clusterer) { + reducerInputs.add(new CentroidWritable(intermediateCentroid)); + postMapperTotalWeight += intermediateCentroid.getWeight(); + } + + reduceDriver.addInput(new IntWritable(0), reducerInputs); + List<org.apache.hadoop.mrunit.types.Pair<IntWritable, CentroidWritable>> results = + reduceDriver.run(); + testReducerResults(postMapperTotalWeight, results); + } + + @Test + public void testHypercubeMapReduce() throws IOException { + MapReduceDriver<Writable, VectorWritable, IntWritable, CentroidWritable, IntWritable, CentroidWritable> + mapReduceDriver = new MapReduceDriver<>(new StreamingKMeansMapper(), new StreamingKMeansReducer()); + Configuration configuration = mapReduceDriver.getConfiguration(); + configure(configuration); + + System.out.printf("%s full test\n", configuration.get(StreamingKMeansDriver.SEARCHER_CLASS_OPTION)); + for (Centroid datapoint : syntheticData.getFirst()) { + mapReduceDriver.addInput(new IntWritable(0), new VectorWritable(datapoint)); + } + List<org.apache.hadoop.mrunit.types.Pair<IntWritable, 
CentroidWritable>> results = mapReduceDriver.run(); + testReducerResults(syntheticData.getFirst().size(), results); + } + + @Test + public void testHypercubeMapReduceRunSequentially() throws Exception { + Configuration configuration = getConfiguration(); + configure(configuration); + configuration.set(DefaultOptionCreator.METHOD_OPTION, DefaultOptionCreator.SEQUENTIAL_METHOD); + + Path inputPath = new Path("testInput"); + Path outputPath = new Path("testOutput"); + StreamingKMeansUtilsMR.writeVectorsToSequenceFile(syntheticData.getFirst(), inputPath, configuration); + + StreamingKMeansDriver.run(configuration, inputPath, outputPath); + + testReducerResults(syntheticData.getFirst().size(), + Lists.newArrayList(Iterables.transform( + new SequenceFileIterable<IntWritable, CentroidWritable>(outputPath, configuration), + new Function< + Pair<IntWritable, CentroidWritable>, + org.apache.hadoop.mrunit.types.Pair<IntWritable, CentroidWritable>>() { + @Override + public org.apache.hadoop.mrunit.types.Pair<IntWritable, CentroidWritable> apply( + org.apache.mahout.common.Pair<IntWritable, CentroidWritable> input) { + return new org.apache.hadoop.mrunit.types.Pair<>( + input.getFirst(), input.getSecond()); + } + }))); + } + + private static void testReducerResults(int totalWeight, List<org.apache.hadoop.mrunit.types.Pair<IntWritable, + CentroidWritable>> results) { + int expectedNumClusters = 1 << NUM_DIMENSIONS; + double expectedWeight = (double) totalWeight / expectedNumClusters; + int numClusters = 0; + int numUnbalancedClusters = 0; + int totalReducerWeight = 0; + for (org.apache.hadoop.mrunit.types.Pair<IntWritable, CentroidWritable> result : results) { + if (result.getSecond().getCentroid().getWeight() != expectedWeight) { + System.out.printf("Unbalanced weight %f in centroid %d\n", result.getSecond().getCentroid().getWeight(), + result.getSecond().getCentroid().getIndex()); + ++numUnbalancedClusters; + } + assertEquals("Final centroid index is invalid", numClusters, result.getFirst().get()); + totalReducerWeight += result.getSecond().getCentroid().getWeight(); + ++numClusters; + } + System.out.printf("%d clusters are unbalanced\n", numUnbalancedClusters); + assertEquals("Invalid total weight", totalWeight, totalReducerWeight); + assertEquals("Invalid number of clusters", 1 << NUM_DIMENSIONS, numClusters); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/clustering/streaming/tools/ResplitSequenceFilesTest.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/clustering/streaming/tools/ResplitSequenceFilesTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/clustering/streaming/tools/ResplitSequenceFilesTest.java new file mode 100644 index 0000000..2d790e5 --- /dev/null +++ b/community/mahout-mr/src/test/java/org/apache/mahout/clustering/streaming/tools/ResplitSequenceFilesTest.java @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.clustering.streaming.tools; + +import com.google.common.collect.Iterables; +import com.google.common.io.Closeables; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FileStatus; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.LocalFileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.SequenceFile; +import org.apache.mahout.common.HadoopUtil; +import org.apache.mahout.common.MahoutTestCase; +import org.apache.mahout.common.iterator.sequencefile.PathFilters; +import org.apache.mahout.common.iterator.sequencefile.PathType; +import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable; +import org.junit.Test; + +public class ResplitSequenceFilesTest extends MahoutTestCase { + + @Test + public void testSplitting() throws Exception { + + Path inputFile = new Path(getTestTempDirPath("input"), "test.seq"); + Path output = getTestTempDirPath("output"); + Configuration conf = new Configuration(); + LocalFileSystem fs = FileSystem.getLocal(conf); + + SequenceFile.Writer writer = null; + try { + writer = SequenceFile.createWriter(fs, conf, inputFile, IntWritable.class, IntWritable.class); + writer.append(new IntWritable(1), new IntWritable(1)); + writer.append(new IntWritable(2), new IntWritable(2)); + writer.append(new IntWritable(3), new IntWritable(3)); + writer.append(new IntWritable(4), new IntWritable(4)); + writer.append(new IntWritable(5), new IntWritable(5)); + writer.append(new IntWritable(6), new IntWritable(6)); + writer.append(new IntWritable(7), new IntWritable(7)); + writer.append(new IntWritable(8), new IntWritable(8)); + } finally { + Closeables.close(writer, false); + } + + String splitPattern = "split"; + int numSplits = 4; + + ResplitSequenceFiles.main(new String[] { "--input", inputFile.toString(), + "--output", output.toString() + "/" + splitPattern, "--numSplits", String.valueOf(numSplits) }); + + FileStatus[] statuses = HadoopUtil.getFileStatus(output, PathType.LIST, PathFilters.logsCRCFilter(), null, conf); + + for (FileStatus status : statuses) { + String name = status.getPath().getName(); + assertTrue(name.startsWith(splitPattern)); + assertEquals(2, numEntries(status, conf)); + } + assertEquals(numSplits, statuses.length); + } + + private int numEntries(FileStatus status, Configuration conf) { + return Iterables.size(new SequenceFileIterable(status.getPath(), conf)); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/clustering/topdown/PathDirectoryTest.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/clustering/topdown/PathDirectoryTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/clustering/topdown/PathDirectoryTest.java new file mode 100644 index 0000000..66b66e3 --- /dev/null +++ b/community/mahout-mr/src/test/java/org/apache/mahout/clustering/topdown/PathDirectoryTest.java @@ -0,0 +1,65 @@ +/** + * Licensed to the Apache Software 
Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.mahout.clustering.topdown; + +import org.apache.hadoop.fs.Path; +import org.apache.mahout.common.MahoutTestCase; +import org.junit.Test; + +import java.io.File; + +public final class PathDirectoryTest extends MahoutTestCase { + + private final Path output = new Path("output"); + + @Test + public void shouldReturnTopLevelClusterPath() { + Path expectedPath = new Path(output, PathDirectory.TOP_LEVEL_CLUSTER_DIRECTORY); + assertEquals(expectedPath, PathDirectory.getTopLevelClusterPath(output)); + } + + @Test + public void shouldReturnClusterPostProcessorOutputDirectory() { + Path expectedPath = new Path(output, PathDirectory.POST_PROCESS_DIRECTORY); + assertEquals(expectedPath, PathDirectory.getClusterPostProcessorOutputDirectory(output)); + } + + @Test + public void shouldReturnClusterOutputClusteredPoints() { + Path expectedPath = new Path(output, PathDirectory.CLUSTERED_POINTS_DIRECTORY + File.separator + '*'); + assertEquals(expectedPath, PathDirectory.getClusterOutputClusteredPoints(output)); + } + + @Test + public void shouldReturnBottomLevelClusterPath() { + Path expectedPath = new Path(output + File.separator + + PathDirectory.BOTTOM_LEVEL_CLUSTER_DIRECTORY + File.separator + + '1'); + assertEquals(expectedPath, PathDirectory.getBottomLevelClusterPath(output, "1")); + } + + @Test + public void shouldReturnClusterPathForClusterId() { + Path expectedPath = new Path(PathDirectory.getClusterPostProcessorOutputDirectory(output), new Path("1")); + assertEquals(expectedPath, PathDirectory.getClusterPathForClusterId( + PathDirectory.getClusterPostProcessorOutputDirectory(output), "1")); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/clustering/topdown/postprocessor/ClusterCountReaderTest.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/clustering/topdown/postprocessor/ClusterCountReaderTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/clustering/topdown/postprocessor/ClusterCountReaderTest.java new file mode 100644 index 0000000..d5a9a90 --- /dev/null +++ b/community/mahout-mr/src/test/java/org/apache/mahout/clustering/topdown/postprocessor/ClusterCountReaderTest.java @@ -0,0 +1,121 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.clustering.topdown.postprocessor; + +import java.io.IOException; +import java.util.List; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.IntWritable; +import org.apache.mahout.clustering.Cluster; +import org.apache.mahout.clustering.ClusteringTestUtils; +import org.apache.mahout.clustering.canopy.CanopyDriver; +import org.apache.mahout.clustering.classify.WeightedVectorWritable; +import org.apache.mahout.clustering.kmeans.KMeansDriver; +import org.apache.mahout.common.DummyOutputCollector; +import org.apache.mahout.common.MahoutTestCase; +import org.apache.mahout.common.Pair; +import org.apache.mahout.common.distance.DistanceMeasure; +import org.apache.mahout.common.distance.ManhattanDistanceMeasure; +import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable; +import org.apache.mahout.math.RandomAccessSparseVector; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.VectorWritable; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import com.google.common.collect.Lists; + +public final class ClusterCountReaderTest extends MahoutTestCase { + + public static final double[][] REFERENCE = { {1, 1}, {2, 1}, {1, 2}, {4, 4}, {5, 4}, {4, 5}, {5, 5}}; + + private FileSystem fs; + private Path outputPathForCanopy; + private Path outputPathForKMeans; + + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + Configuration conf = getConfiguration(); + fs = FileSystem.get(conf); + } + + public static List<VectorWritable> getPointsWritable(double[][] raw) { + List<VectorWritable> points = Lists.newArrayList(); + for (double[] fr : raw) { + Vector vec = new RandomAccessSparseVector(fr.length); + vec.assign(fr); + points.add(new VectorWritable(vec)); + } + return points; + } + + /** + * Story: User wants to use cluster post processor after canopy clustering and then run clustering on the + * output clusters + */ + @Test + public void testGetNumberOfClusters() throws Exception { + List<VectorWritable> points = getPointsWritable(REFERENCE); + + Path pointsPath = getTestTempDirPath("points"); + Configuration conf = getConfiguration(); + ClusteringTestUtils.writePointsToFile(points, new Path(pointsPath, "file1"), fs, conf); + ClusteringTestUtils.writePointsToFile(points, new Path(pointsPath, "file2"), fs, conf); + + outputPathForCanopy = getTestTempDirPath("canopy"); + outputPathForKMeans = getTestTempDirPath("kmeans"); + + topLevelClustering(pointsPath, conf); + + int numberOfClusters = ClusterCountReader.getNumberOfClusters(outputPathForKMeans, conf); + Assert.assertEquals(2, numberOfClusters); + verifyThatNumberOfClustersIsCorrect(conf, new Path(outputPathForKMeans, new Path("clusteredPoints"))); + + } + + private void topLevelClustering(Path pointsPath, Configuration conf) throws IOException, + InterruptedException, + ClassNotFoundException { + DistanceMeasure measure = new ManhattanDistanceMeasure(); + CanopyDriver.run(conf, pointsPath, outputPathForCanopy, measure, 
4.0, 3.0, true, 0.0, true); + Path clustersIn = new Path(outputPathForCanopy, new Path(Cluster.CLUSTERS_DIR + '0' + + Cluster.FINAL_ITERATION_SUFFIX)); + KMeansDriver.run(conf, pointsPath, clustersIn, outputPathForKMeans, 1, 1, true, 0.0, true); + } + + private static void verifyThatNumberOfClustersIsCorrect(Configuration conf, Path clusteredPointsPath) { + DummyOutputCollector<IntWritable,WeightedVectorWritable> collector = + new DummyOutputCollector<>(); + + // The key is the clusterId, the value is the weighted vector + for (Pair<IntWritable,WeightedVectorWritable> record : + new SequenceFileIterable<IntWritable,WeightedVectorWritable>(new Path(clusteredPointsPath, "part-m-0"), + conf)) { + collector.collect(record.getFirst(), record.getSecond()); + } + int clusterSize = collector.getKeys().size(); + assertEquals(2, clusterSize); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/clustering/topdown/postprocessor/ClusterOutputPostProcessorTest.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/clustering/topdown/postprocessor/ClusterOutputPostProcessorTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/clustering/topdown/postprocessor/ClusterOutputPostProcessorTest.java new file mode 100644 index 0000000..0fab2fe --- /dev/null +++ b/community/mahout-mr/src/test/java/org/apache/mahout/clustering/topdown/postprocessor/ClusterOutputPostProcessorTest.java @@ -0,0 +1,205 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.mahout.clustering.topdown.postprocessor; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; + +import org.apache.commons.lang3.ArrayUtils; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FileStatus; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.FileUtil; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.LongWritable; +import org.apache.hadoop.io.SequenceFile; +import org.apache.hadoop.io.Writable; +import org.apache.mahout.clustering.ClusteringTestUtils; +import org.apache.mahout.clustering.canopy.CanopyDriver; +import org.apache.mahout.clustering.classify.WeightedVectorWritable; +import org.apache.mahout.clustering.topdown.PathDirectory; +import org.apache.mahout.common.DummyOutputCollector; +import org.apache.mahout.common.MahoutTestCase; +import org.apache.mahout.common.Pair; +import org.apache.mahout.common.distance.ManhattanDistanceMeasure; +import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable; +import org.apache.mahout.math.RandomAccessSparseVector; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.VectorWritable; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import com.google.common.collect.Lists; + +public final class ClusterOutputPostProcessorTest extends MahoutTestCase { + + private static final double[][] REFERENCE = { {1, 1}, {2, 1}, {1, 2}, {4, 4}, {5, 4}, {4, 5}, {5, 5}}; + + private FileSystem fs; + + private Path outputPath; + + private Configuration conf; + + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + Configuration conf = getConfiguration(); + fs = FileSystem.get(conf); + } + + private static List<VectorWritable> getPointsWritable(double[][] raw) { + List<VectorWritable> points = Lists.newArrayList(); + for (double[] fr : raw) { + Vector vec = new RandomAccessSparseVector(fr.length); + vec.assign(fr); + points.add(new VectorWritable(vec)); + } + return points; + } + + /** + * Story: User wants to use cluster post processor after canopy clustering and then run clustering on the + * output clusters + */ + @Test + public void testTopDownClustering() throws Exception { + List<VectorWritable> points = getPointsWritable(REFERENCE); + + Path pointsPath = getTestTempDirPath("points"); + conf = getConfiguration(); + ClusteringTestUtils.writePointsToFile(points, new Path(pointsPath, "file1"), fs, conf); + ClusteringTestUtils.writePointsToFile(points, new Path(pointsPath, "file2"), fs, conf); + + outputPath = getTestTempDirPath("output"); + + topLevelClustering(pointsPath, conf); + + Map<String,Path> postProcessedClusterDirectories = ouputPostProcessing(conf); + + assertPostProcessedOutput(postProcessedClusterDirectories); + + bottomLevelClustering(postProcessedClusterDirectories); + } + + private void assertTopLevelCluster(Entry<String,Path> cluster) { + String clusterId = cluster.getKey(); + Path clusterPath = cluster.getValue(); + + try { + if ("0".equals(clusterId)) { + assertPointsInFirstTopLevelCluster(clusterPath); + } else if ("1".equals(clusterId)) { + assertPointsInSecondTopLevelCluster(clusterPath); + } + } catch (IOException e) { + Assert.fail("Exception occurred while asserting top level cluster."); + } + + } + + private void assertPointsInFirstTopLevelCluster(Path clusterPath) throws IOException { + List<Vector> vectorsInCluster = getVectorsInCluster(clusterPath); + for 
(Vector vector : vectorsInCluster) { + Assert.assertTrue(ArrayUtils.contains(new String[] {"{0:1.0,1:1.0}", "{0:2.0,1:1.0}", "{0:1.0,1:2.0}"}, + vector.asFormatString())); + } + } + + private void assertPointsInSecondTopLevelCluster(Path clusterPath) throws IOException { + List<Vector> vectorsInCluster = getVectorsInCluster(clusterPath); + for (Vector vector : vectorsInCluster) { + Assert.assertTrue(ArrayUtils.contains(new String[] {"{0:4.0,1:4.0}", "{0:5.0,1:4.0}", "{0:4.0,1:5.0}", + "{0:5.0,1:5.0}"}, vector.asFormatString())); + } + } + + private List<Vector> getVectorsInCluster(Path clusterPath) throws IOException { + Path[] partFilePaths = FileUtil.stat2Paths(fs.globStatus(clusterPath)); + FileStatus[] listStatus = fs.listStatus(partFilePaths); + List<Vector> vectors = Lists.newArrayList(); + for (FileStatus partFile : listStatus) { + SequenceFile.Reader topLevelClusterReader = new SequenceFile.Reader(fs, partFile.getPath(), conf); + Writable clusterIdAsKey = new LongWritable(); + VectorWritable point = new VectorWritable(); + while (topLevelClusterReader.next(clusterIdAsKey, point)) { + vectors.add(point.get()); + } + } + return vectors; + } + + private void bottomLevelClustering(Map<String,Path> postProcessedClusterDirectories) throws IOException, + InterruptedException, + ClassNotFoundException { + for (Entry<String,Path> topLevelCluster : postProcessedClusterDirectories.entrySet()) { + String clusterId = topLevelCluster.getKey(); + Path topLevelclusterPath = topLevelCluster.getValue(); + + Path bottomLevelCluster = PathDirectory.getBottomLevelClusterPath(outputPath, clusterId); + CanopyDriver.run(conf, topLevelclusterPath, bottomLevelCluster, new ManhattanDistanceMeasure(), 2.1, + 2.0, true, 0.0, true); + assertBottomLevelCluster(bottomLevelCluster); + } + } + + private void assertBottomLevelCluster(Path bottomLevelCluster) { + Path clusteredPointsPath = new Path(bottomLevelCluster, "clusteredPoints"); + + DummyOutputCollector<IntWritable,WeightedVectorWritable> collector = + new DummyOutputCollector<>(); + + // The key is the clusterId, the value is the weighted vector + for (Pair<IntWritable,WeightedVectorWritable> record : + new SequenceFileIterable<IntWritable,WeightedVectorWritable>(new Path(clusteredPointsPath, "part-m-0"), + conf)) { + collector.collect(record.getFirst(), record.getSecond()); + } + int clusterSize = collector.getKeys().size(); + // First top level cluster produces two more clusters, second top level cluster is not broken again + assertTrue(clusterSize == 1 || clusterSize == 2); + + } + + private void assertPostProcessedOutput(Map<String,Path> postProcessedClusterDirectories) { + for (Entry<String,Path> cluster : postProcessedClusterDirectories.entrySet()) { + assertTopLevelCluster(cluster); + } + } + + private Map<String,Path> ouputPostProcessing(Configuration conf) throws IOException { + ClusterOutputPostProcessor clusterOutputPostProcessor = new ClusterOutputPostProcessor(outputPath, + outputPath, conf); + clusterOutputPostProcessor.process(); + return clusterOutputPostProcessor.getPostProcessedClusterDirectories(); + } + + private void topLevelClustering(Path pointsPath, Configuration conf) throws IOException, + InterruptedException, + ClassNotFoundException { + CanopyDriver.run(conf, pointsPath, outputPath, new ManhattanDistanceMeasure(), 3.1, 2.1, true, 0.0, true); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/common/AbstractJobTest.java 
---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/common/AbstractJobTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/common/AbstractJobTest.java new file mode 100644 index 0000000..7683b57 --- /dev/null +++ b/community/mahout-mr/src/test/java/org/apache/mahout/common/AbstractJobTest.java @@ -0,0 +1,240 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.mahout.common; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +import com.google.common.collect.Maps; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.util.ToolRunner; +import org.apache.mahout.common.commandline.DefaultOptionCreator; +import org.junit.Test; + +public final class AbstractJobTest extends MahoutTestCase { + + interface AbstractJobFactory { + AbstractJob getJob(); + } + + @Test + public void testFlag() throws Exception { + final Map<String,List<String>> testMap = Maps.newHashMap(); + + AbstractJobFactory fact = new AbstractJobFactory() { + @Override + public AbstractJob getJob() { + return new AbstractJob() { + @Override + public int run(String[] args) throws IOException { + addFlag("testFlag", "t", "a simple test flag"); + + Map<String,List<String>> argMap = parseArguments(args); + testMap.clear(); + testMap.putAll(argMap); + return 1; + } + }; + } + }; + + // testFlag will only be present if specified on the command-line + + ToolRunner.run(fact.getJob(), new String[0]); + assertFalse("test map for absent flag", testMap.containsKey("--testFlag")); + + String[] withFlag = { "--testFlag" }; + ToolRunner.run(fact.getJob(), withFlag); + assertTrue("test map for present flag", testMap.containsKey("--testFlag")); + } + + @Test + public void testOptions() throws Exception { + final Map<String,List<String>> testMap = Maps.newHashMap(); + + AbstractJobFactory fact = new AbstractJobFactory() { + @Override + public AbstractJob getJob() { + return new AbstractJob() { + @Override + public int run(String[] args) throws IOException { + this.addOption(DefaultOptionCreator.overwriteOption().create()); + this.addOption("option", "o", "option"); + this.addOption("required", "r", "required", true /* required */); + this.addOption("notRequired", "nr", "not required", false /* not required */); + this.addOption("hasDefault", "hd", "option w/ default", "defaultValue"); + + + Map<String,List<String>> argMap = parseArguments(args); + if (argMap == null) { + return -1; + } + + testMap.clear(); + testMap.putAll(argMap); + + return 0; + } + }; + } + }; + + int ret = ToolRunner.run(fact.getJob(), new String[0]); + assertEquals("-1 for missing required options", -1, ret); + + ret = 
ToolRunner.run(fact.getJob(), new String[]{ + "--required", "requiredArg" + }); + assertEquals("0 for no missing required options", 0, ret); + assertEquals(Collections.singletonList("requiredArg"), testMap.get("--required")); + assertEquals(Collections.singletonList("defaultValue"), testMap.get("--hasDefault")); + assertNull(testMap.get("--option")); + assertNull(testMap.get("--notRequired")); + assertFalse(testMap.containsKey("--overwrite")); + + ret = ToolRunner.run(fact.getJob(), new String[]{ + "--required", "requiredArg", + "--unknownArg" + }); + assertEquals("-1 for including unknown options", -1, ret); + + ret = ToolRunner.run(fact.getJob(), new String[]{ + "--required", "requiredArg", + "--required", "requiredArg2", + }); + assertEquals("-1 for including duplicate options", -1, ret); + + ret = ToolRunner.run(fact.getJob(), new String[]{ + "--required", "requiredArg", + "--overwrite", + "--hasDefault", "nonDefault", + "--option", "optionValue", + "--notRequired", "notRequired" + }); + assertEquals("0 for no missing required options", 0, ret); + assertEquals(Collections.singletonList("requiredArg"), testMap.get("--required")); + assertEquals(Collections.singletonList("nonDefault"), testMap.get("--hasDefault")); + assertEquals(Collections.singletonList("optionValue"), testMap.get("--option")); + assertEquals(Collections.singletonList("notRequired"), testMap.get("--notRequired")); + assertTrue(testMap.containsKey("--overwrite")); + + ret = ToolRunner.run(fact.getJob(), new String[]{ + "-r", "requiredArg", + "-ow", + "-hd", "nonDefault", + "-o", "optionValue", + "-nr", "notRequired" + }); + assertEquals("0 for no missing required options", 0, ret); + assertEquals(Collections.singletonList("requiredArg"), testMap.get("--required")); + assertEquals(Collections.singletonList("nonDefault"), testMap.get("--hasDefault")); + assertEquals(Collections.singletonList("optionValue"), testMap.get("--option")); + assertEquals(Collections.singletonList("notRequired"), testMap.get("--notRequired")); + assertTrue(testMap.containsKey("--overwrite")); + + } + + @Test + public void testInputOutputPaths() throws Exception { + + AbstractJobFactory fact = new AbstractJobFactory() { + @Override + public AbstractJob getJob() { + return new AbstractJob() { + @Override + public int run(String[] args) throws IOException { + addInputOption(); + addOutputOption(); + + // arg map should be null if a required option is missing. 
+ Map<String, List<String>> argMap = parseArguments(args); + + if (argMap == null) { + return -1; + } + + Path inputPath = getInputPath(); + assertNotNull("getInputPath() returns non-null", inputPath); + + Path outputPath = getOutputPath(); + assertNotNull("getOutputPath() returns non-null", outputPath); + return 0; + } + }; + } + }; + + int ret = ToolRunner.run(fact.getJob(), new String[0]); + assertEquals("-1 for missing input option", -1, ret); + + String testInputPath = "testInputPath"; + + AbstractJob job = fact.getJob(); + ret = ToolRunner.run(job, new String[]{ + "--input", testInputPath }); + assertEquals("-1 for missing output option", -1, ret); + assertEquals("input path is correct", testInputPath, job.getInputPath().toString()); + + job = fact.getJob(); + String testOutputPath = "testOutputPath"; + ret = ToolRunner.run(job, new String[]{ + "--output", testOutputPath }); + assertEquals("-1 for missing input option", -1, ret); + assertEquals("output path is correct", testOutputPath, job.getOutputPath().toString()); + + job = fact.getJob(); + ret = ToolRunner.run(job, new String[]{ + "--input", testInputPath, "--output", testOutputPath }); + assertEquals("0 for complete options", 0, ret); + assertEquals("input path is correct", testInputPath, job.getInputPath().toString()); + assertEquals("output path is correct", testOutputPath, job.getOutputPath().toString()); + + job = fact.getJob(); + ret = ToolRunner.run(job, new String[]{ + "--input", testInputPath, "--output", testOutputPath }); + assertEquals("0 for complete options", 0, ret); + assertEquals("input path is correct", testInputPath, job.getInputPath().toString()); + assertEquals("output path is correct", testOutputPath, job.getOutputPath().toString()); + + job = fact.getJob(); + String testInputPropertyPath = "testInputPropertyPath"; + String testOutputPropertyPath = "testOutputPropertyPath"; + ret = ToolRunner.run(job, new String[]{ + "-Dmapred.input.dir=" + testInputPropertyPath, + "-Dmapred.output.dir=" + testOutputPropertyPath }); + assertEquals("0 for complete options", 0, ret); + assertEquals("input path from property is correct", testInputPropertyPath, job.getInputPath().toString()); + assertEquals("output path from property is correct", testOutputPropertyPath, job.getOutputPath().toString()); + + job = fact.getJob(); + ret = ToolRunner.run(job, new String[]{ + "-Dmapred.input.dir=" + testInputPropertyPath, + "-Dmapred.output.dir=" + testOutputPropertyPath, + "--input", testInputPath, + "--output", testOutputPath }); + assertEquals("0 for complete options", 0, ret); + assertEquals("input command-line option precedes property", + testInputPath, job.getInputPath().toString()); + assertEquals("output command-line option precedes property", + testOutputPath, job.getOutputPath().toString()); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/common/DistributedCacheFileLocationTest.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/common/DistributedCacheFileLocationTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/common/DistributedCacheFileLocationTest.java new file mode 100644 index 0000000..5d3532c --- /dev/null +++ b/community/mahout-mr/src/test/java/org/apache/mahout/common/DistributedCacheFileLocationTest.java @@ -0,0 +1,46 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements.
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.common; + +import org.apache.hadoop.fs.Path; +import org.junit.Test; + +import java.io.File; +import java.net.URI; + + +public class DistributedCacheFileLocationTest extends MahoutTestCase { + + static final File FILE_I_WANT_TO_FIND = new File("file/i_want_to_find.txt"); + static final URI[] DISTRIBUTED_CACHE_FILES = new URI[] { + new File("/first/file").toURI(), new File("/second/file").toURI(), FILE_I_WANT_TO_FIND.toURI() }; + + @Test + public void nonExistingFile() { + Path path = HadoopUtil.findInCacheByPartOfFilename("no such file", DISTRIBUTED_CACHE_FILES); + assertNull(path); + } + + @Test + public void existingFile() { + Path path = HadoopUtil.findInCacheByPartOfFilename("want_to_find", DISTRIBUTED_CACHE_FILES); + assertNotNull(path); + assertEquals(FILE_I_WANT_TO_FIND.getName(), path.getName()); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/common/DummyOutputCollector.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/common/DummyOutputCollector.java b/community/mahout-mr/src/test/java/org/apache/mahout/common/DummyOutputCollector.java new file mode 100644 index 0000000..8f89623 --- /dev/null +++ b/community/mahout-mr/src/test/java/org/apache/mahout/common/DummyOutputCollector.java @@ -0,0 +1,57 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.common; + +import com.google.common.collect.Lists; +import org.apache.hadoop.io.Writable; +import org.apache.hadoop.io.WritableComparable; +import org.apache.hadoop.mapred.OutputCollector; + +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.TreeMap; + +public final class DummyOutputCollector<K extends WritableComparable, V extends Writable> + implements OutputCollector<K,V> { + + private final Map<K, List<V>> data = new TreeMap<>(); + + @Override + public void collect(K key,V values) { + List<V> points = data.get(key); + if (points == null) { + points = Lists.newArrayList(); + data.put(key, points); + } + points.add(values); + } + + public Map<K,List<V>> getData() { + return data; + } + + public List<V> getValue(K key) { + return data.get(key); + } + + public Set<K> getKeys() { + return data.keySet(); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/common/DummyRecordWriter.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/common/DummyRecordWriter.java b/community/mahout-mr/src/test/java/org/apache/mahout/common/DummyRecordWriter.java new file mode 100644 index 0000000..61b768a --- /dev/null +++ b/community/mahout-mr/src/test/java/org/apache/mahout/common/DummyRecordWriter.java @@ -0,0 +1,223 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.common; + +import com.google.common.collect.Lists; +import com.google.common.collect.Maps; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.io.NullWritable; +import org.apache.hadoop.io.Writable; +import org.apache.hadoop.mapreduce.MapContext; +import org.apache.hadoop.mapreduce.Mapper; +import org.apache.hadoop.mapreduce.RecordWriter; +import org.apache.hadoop.mapreduce.ReduceContext; +import org.apache.hadoop.mapreduce.Reducer; +import org.apache.hadoop.mapreduce.TaskAttemptContext; +import org.apache.hadoop.mapreduce.TaskAttemptID; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.IOException; +import java.lang.reflect.Constructor; +import java.lang.reflect.Method; +import java.util.List; +import java.util.Map; +import java.util.Set; + +public final class DummyRecordWriter<K extends Writable, V extends Writable> extends RecordWriter<K, V> { + + private final List<K> keysInInsertionOrder = Lists.newArrayList(); + private final Map<K, List<V>> data = Maps.newHashMap(); + + @Override + public void write(K key, V value) { + + // if the user reuses the same writable class, we need to create a new one + // otherwise the Map content will be modified after the insert + try { + + K keyToUse = key instanceof NullWritable ? key : (K) cloneWritable(key); + V valueToUse = (V) cloneWritable(value); + + keysInInsertionOrder.add(keyToUse); + + List<V> points = data.get(key); + if (points == null) { + points = Lists.newArrayList(); + data.put(keyToUse, points); + } + points.add(valueToUse); + + } catch (IOException e) { + throw new RuntimeException(e.getMessage(), e); + } + } + + private Writable cloneWritable(Writable original) throws IOException { + + Writable clone; + try { + clone = original.getClass().asSubclass(Writable.class).newInstance(); + } catch (Exception e) { + throw new RuntimeException("Unable to instantiate writable!", e); + } + ByteArrayOutputStream bytes = new ByteArrayOutputStream(); + + original.write(new DataOutputStream(bytes)); + clone.readFields(new DataInputStream(new ByteArrayInputStream(bytes.toByteArray()))); + + return clone; + } + + @Override + public void close(TaskAttemptContext context) { + } + + public Map<K, List<V>> getData() { + return data; + } + + public List<V> getValue(K key) { + return data.get(key); + } + + public Set<K> getKeys() { + return data.keySet(); + } + + public Iterable<K> getKeysInInsertionOrder() { + return keysInInsertionOrder; + } + + public static <K1, V1, K2, V2> Mapper<K1, V1, K2, V2>.Context build(Mapper<K1, V1, K2, V2> mapper, + Configuration configuration, + RecordWriter<K2, V2> output) { + + // Use reflection since the context types changed incompatibly between 0.20 + // and 0.23. + try { + return buildNewMapperContext(configuration, output); + } catch (Exception|IncompatibleClassChangeError e) { + try { + return buildOldMapperContext(mapper, configuration, output); + } catch (Exception ex) { + throw new IllegalStateException(ex); + } + } + } + + public static <K1, V1, K2, V2> Reducer<K1, V1, K2, V2>.Context build(Reducer<K1, V1, K2, V2> reducer, + Configuration configuration, + RecordWriter<K2, V2> output, + Class<K1> keyClass, + Class<V1> valueClass) { + + // Use reflection since the context types changed incompatibly between 0.20 + // and 0.23. 
+ try { + return buildNewReducerContext(configuration, output, keyClass, valueClass); + } catch (Exception|IncompatibleClassChangeError e) { + try { + return buildOldReducerContext(reducer, configuration, output, keyClass, valueClass); + } catch (Exception ex) { + throw new IllegalStateException(ex); + } + } + } + + @SuppressWarnings({"unchecked", "rawtypes"}) + private static <K1, V1, K2, V2> Mapper<K1, V1, K2, V2>.Context buildNewMapperContext( + Configuration configuration, RecordWriter<K2, V2> output) throws Exception { + Class<?> mapContextImplClass = Class.forName("org.apache.hadoop.mapreduce.task.MapContextImpl"); + Constructor<?> cons = mapContextImplClass.getConstructors()[0]; + Object mapContextImpl = cons.newInstance(configuration, + new TaskAttemptID(), null, output, null, new DummyStatusReporter(), null); + + Class<?> wrappedMapperClass = Class.forName("org.apache.hadoop.mapreduce.lib.map.WrappedMapper"); + Object wrappedMapper = wrappedMapperClass.getConstructor().newInstance(); + Method getMapContext = wrappedMapperClass.getMethod("getMapContext", MapContext.class); + return (Mapper.Context) getMapContext.invoke(wrappedMapper, mapContextImpl); + } + + @SuppressWarnings({"unchecked", "rawtypes"}) + private static <K1, V1, K2, V2> Mapper<K1, V1, K2, V2>.Context buildOldMapperContext( + Mapper<K1, V1, K2, V2> mapper, Configuration configuration, + RecordWriter<K2, V2> output) throws Exception { + Constructor<?> cons = getNestedContextConstructor(mapper.getClass()); + // first argument to the constructor is the enclosing instance + return (Mapper.Context) cons.newInstance(mapper, configuration, + new TaskAttemptID(), null, output, null, new DummyStatusReporter(), null); + } + + @SuppressWarnings({"unchecked", "rawtypes"}) + private static <K1, V1, K2, V2> Reducer<K1, V1, K2, V2>.Context buildNewReducerContext( + Configuration configuration, RecordWriter<K2, V2> output, Class<K1> keyClass, + Class<V1> valueClass) throws Exception { + Class<?> reduceContextImplClass = Class.forName("org.apache.hadoop.mapreduce.task.ReduceContextImpl"); + Constructor<?> cons = reduceContextImplClass.getConstructors()[0]; + Object reduceContextImpl = cons.newInstance(configuration, + new TaskAttemptID(), + new MockIterator(), + null, + null, + output, + null, + new DummyStatusReporter(), + null, + keyClass, + valueClass); + + Class<?> wrappedReducerClass = Class.forName("org.apache.hadoop.mapreduce.lib.reduce.WrappedReducer"); + Object wrappedReducer = wrappedReducerClass.getConstructor().newInstance(); + Method getReducerContext = wrappedReducerClass.getMethod("getReducerContext", ReduceContext.class); + return (Reducer.Context) getReducerContext.invoke(wrappedReducer, reduceContextImpl); + } + + @SuppressWarnings({"unchecked", "rawtypes"}) + private static <K1, V1, K2, V2> Reducer<K1, V1, K2, V2>.Context buildOldReducerContext( + Reducer<K1, V1, K2, V2> reducer, Configuration configuration, + RecordWriter<K2, V2> output, Class<K1> keyClass, + Class<V1> valueClass) throws Exception { + Constructor<?> cons = getNestedContextConstructor(reducer.getClass()); + // first argument to the constructor is the enclosing instance + return (Reducer.Context) cons.newInstance(reducer, + configuration, + new TaskAttemptID(), + new MockIterator(), + null, + null, + output, + null, + new DummyStatusReporter(), + null, + keyClass, + valueClass); + } + + private static Constructor<?> getNestedContextConstructor(Class<?> outerClass) { + for (Class<?> nestedClass : outerClass.getClasses()) { + if 
("Context".equals(nestedClass.getSimpleName())) { + return nestedClass.getConstructors()[0]; + } + } + throw new IllegalStateException("Cannot find context class for " + outerClass); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/common/DummyRecordWriterTest.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/common/DummyRecordWriterTest.java b/community/mahout-mr/src/test/java/org/apache/mahout/common/DummyRecordWriterTest.java new file mode 100644 index 0000000..1d53cc7 --- /dev/null +++ b/community/mahout-mr/src/test/java/org/apache/mahout/common/DummyRecordWriterTest.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.common; + +import org.apache.hadoop.io.IntWritable; +import org.apache.mahout.math.DenseVector; +import org.apache.mahout.math.VectorWritable; +import org.junit.Assert; +import org.junit.Test; + +public class DummyRecordWriterTest { + + @Test + public void testWrite() { + DummyRecordWriter<IntWritable, VectorWritable> writer = + new DummyRecordWriter<>(); + IntWritable reusableIntWritable = new IntWritable(); + VectorWritable reusableVectorWritable = new VectorWritable(); + reusableIntWritable.set(0); + reusableVectorWritable.set(new DenseVector(new double[] { 1, 2, 3 })); + writer.write(reusableIntWritable, reusableVectorWritable); + reusableIntWritable.set(1); + reusableVectorWritable.set(new DenseVector(new double[] { 4, 5, 6 })); + writer.write(reusableIntWritable, reusableVectorWritable); + + Assert.assertEquals( + "The writer must remember the two keys that is written to it", 2, + writer.getKeys().size()); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/test/java/org/apache/mahout/common/DummyStatusReporter.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/test/java/org/apache/mahout/common/DummyStatusReporter.java b/community/mahout-mr/src/test/java/org/apache/mahout/common/DummyStatusReporter.java new file mode 100644 index 0000000..c6bc34b --- /dev/null +++ b/community/mahout-mr/src/test/java/org/apache/mahout/common/DummyStatusReporter.java @@ -0,0 +1,76 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.mahout.common; + +import org.easymock.EasyMock; + +import java.util.Map; + +import com.google.common.collect.Maps; +import org.apache.hadoop.mapreduce.Counter; +import org.apache.hadoop.mapreduce.StatusReporter; + +public final class DummyStatusReporter extends StatusReporter { + + private final Map<Enum<?>, Counter> counters = Maps.newHashMap(); + private final Map<String, Counter> counterGroups = Maps.newHashMap(); + + private static Counter newCounter() { + try { + // 0.23 case + String c = "org.apache.hadoop.mapreduce.counters.GenericCounter"; + return (Counter) EasyMock.createMockBuilder(Class.forName(c)).createMock(); + } catch (ClassNotFoundException e) { + // 0.20 case + return EasyMock.createMockBuilder(Counter.class).createMock(); + } + } + + @Override + public Counter getCounter(Enum<?> name) { + if (!counters.containsKey(name)) { + counters.put(name, newCounter()); + } + return counters.get(name); + } + + + @Override + public Counter getCounter(String group, String name) { + if (!counterGroups.containsKey(group + name)) { + counterGroups.put(group + name, newCounter()); + } + return counterGroups.get(group+name); + } + + @Override + public void progress() { + } + + @Override + public void setStatus(String status) { + } + + @Override + public float getProgress() { + return 0.0f; + } + +}
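For context, a minimal sketch of how the test doubles in this patch (DummyRecordWriter, which wires in DummyStatusReporter) are typically combined to exercise a Mapper in-process. It is illustrative only and not part of the commit above: WordLikeMapper is an assumed Mapper<LongWritable, Text, Text, IntWritable> that overrides map() with public visibility so the test can call it directly, and the assertions are placeholders.

package org.apache.mahout.common;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;

import java.util.List;

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.mapreduce.Mapper;
import org.junit.Test;

public class WordLikeMapperTest {

  @Test
  public void testMapperEmitsExpectedKeys() throws Exception {
    // WordLikeMapper is hypothetical; any Mapper<LongWritable, Text, Text, IntWritable>
    // whose map() is callable from the test would do.
    WordLikeMapper mapper = new WordLikeMapper();
    DummyRecordWriter<Text, IntWritable> output = new DummyRecordWriter<>();

    // DummyRecordWriter.build(...) assembles a Mapper.Context around the dummy
    // writer (using DummyStatusReporter internally), so no cluster is needed.
    Mapper<LongWritable, Text, Text, IntWritable>.Context context =
        DummyRecordWriter.build(mapper, new Configuration(), output);

    // Feed records straight to the mapper; the dummy writer captures its output.
    mapper.map(new LongWritable(0), new Text("first record"), context);
    mapper.map(new LongWritable(1), new Text("second record"), context);

    // Placeholder assertions: everything the mapper emitted is held in memory.
    assertEquals(2, output.getKeys().size());
    for (Text key : output.getKeys()) {
      List<IntWritable> emitted = output.getValue(key);
      assertFalse(emitted.isEmpty());
    }
  }
}

DummyOutputCollector plays the analogous role for code written against the old org.apache.hadoop.mapred.OutputCollector interface, as the clustering post-processor test earlier in this patch shows.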
