http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/streaming/cluster/BallKMeans.java ---------------------------------------------------------------------- diff --git a/mr/src/main/java/org/apache/mahout/clustering/streaming/cluster/BallKMeans.java b/mr/src/main/java/org/apache/mahout/clustering/streaming/cluster/BallKMeans.java new file mode 100644 index 0000000..25a4022 --- /dev/null +++ b/mr/src/main/java/org/apache/mahout/clustering/streaming/cluster/BallKMeans.java @@ -0,0 +1,456 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.clustering.streaming.cluster; + +import java.util.Collections; +import java.util.Iterator; +import java.util.List; +import java.util.Random; + +import com.google.common.base.Function; +import com.google.common.base.Preconditions; +import com.google.common.collect.Iterables; +import com.google.common.collect.Iterators; +import com.google.common.collect.Lists; +import org.apache.mahout.clustering.ClusteringUtils; +import org.apache.mahout.common.Pair; +import org.apache.mahout.common.RandomUtils; +import org.apache.mahout.common.distance.DistanceMeasure; +import org.apache.mahout.math.Centroid; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.WeightedVector; +import org.apache.mahout.math.neighborhood.UpdatableSearcher; +import org.apache.mahout.math.random.Multinomial; +import org.apache.mahout.math.random.WeightedThing; + +/** + * Implements a ball k-means algorithm for weighted vectors with probabilistic seeding similar to k-means++. + * The idea is that k-means++ gives good starting clusters and ball k-means can tune up the final result very nicely + * in only a few passes (or even in a single iteration for well-clusterable data). + * + * A good reference for this class of algorithms is "The Effectiveness of Lloyd-Type Methods for the k-Means Problem" + * by Rafail Ostrovsky, Yuval Rabani, Leonard J. Schulman and Chaitanya Swamy. The code here uses the seeding strategy + * as described in section 4.1.1 of that paper and the ball k-means step as described in section 4.2. We support + * multiple iterations in contrast to the algorithm described in the paper. + */ +public class BallKMeans implements Iterable<Centroid> { + /** + * The searcher containing the centroids. + */ + private final UpdatableSearcher centroids; + + /** + * The number of clusters to cluster the data into. 
+ */ + private final int numClusters; + + /** + * The maximum number of iterations of the algorithm to run waiting for the cluster assignments + * to stabilize. If there are no changes in cluster assignment earlier, we can finish early. + */ + private final int maxNumIterations; + + /** + * When deciding which points to include in the new centroid calculation, + * it's preferable to exclude outliers since it increases the rate of convergence. + * So, we calculate the distance from each cluster to its closest neighboring cluster. When + * evaluating the points assigned to a cluster, we compare the distance between the centroid to + * the point with the distance between the centroid and its closest centroid neighbor + * multiplied by this trimFraction. If the distance between the centroid and the point is + * greater, we consider it an outlier and we don't use it. + */ + private final double trimFraction; + + /** + * Selecting the initial centroids is the most important part of the ball k-means clustering. Poor choices, like two + * centroids in the same actual cluster result in a low-quality final result. + * k-means++ initialization yields good quality clusters, especially when using BallKMeans after StreamingKMeans as + * the points have weights. + * Simple, random selection of the points based on their weights is faster but sometimes fails to produce the + * desired number of clusters. + * This field is true if the initialization should be done with k-means++. + */ + private final boolean kMeansPlusPlusInit; + + /** + * When using trimFraction, the weight of each centroid will not be the sum of the weights of + * the vectors assigned to that cluster because outliers are not used to compute the updated + * centroid. + * So, the total weight is probably wrong. This can be fixed by doing another pass over the + * data points and adjusting the weights of each centroid. This doesn't update the coordinates + * of the centroids, but is useful if the weights matter. 
+ */ + private final boolean correctWeights; + + /** + * When running multiple ball k-means passes to get the one with the smallest total cost, can compute the + * overall cost, using all the points for clustering, or reserve a fraction of them, testProbability in a test set. + * The cost is the sum of the distances between each point and its corresponding centroid. + * We then use this set of points to compute the total cost on. We're therefore trying to select the clustering + * that best describes the underlying distribution of the clusters. + * This field is the probability of assigning a given point to the test set. If this is 0, the cost will be computed + * on the entire set of points. + */ + private final double testProbability; + + /** + * Whether or not testProbability > 0, i.e., there exists a non-empty 'test' set. + */ + private final boolean splitTrainTest; + + /** + * How many k-means runs to have. If there's more than one run, we compute the cost of each clustering as described + * above and select the clustering that minimizes the cost. + * Multiple runs are a lot more useful when using the random initialization. With kmeans++, 1-2 runs are enough and + * more runs are not likely to help quality much. + */ + private final int numRuns; + + /** + * Random object to sample values from. + */ + private final Random random; + + public BallKMeans(UpdatableSearcher searcher, int numClusters, int maxNumIterations) { + // By default, the trimFraction is 0.9, k-means++ is used, the weights will be corrected at the end, + // there will be 0 points in the test set and 1 run. + this(searcher, numClusters, maxNumIterations, 0.9, true, true, 0.0, 1); + } + + public BallKMeans(UpdatableSearcher searcher, int numClusters, int maxNumIterations, + boolean kMeansPlusPlusInit, int numRuns) { + // By default, the trimFraction is 0.9, k-means++ is used, the weights will be corrected at the end, + // there will be 10% points of in the test set. 
+ this(searcher, numClusters, maxNumIterations, 0.9, kMeansPlusPlusInit, true, 0.1, numRuns); + } + + public BallKMeans(UpdatableSearcher searcher, int numClusters, int maxNumIterations, + double trimFraction, boolean kMeansPlusPlusInit, boolean correctWeights, + double testProbability, int numRuns) { + Preconditions.checkArgument(searcher.size() == 0, "Searcher must be empty initially to populate with centroids"); + Preconditions.checkArgument(numClusters > 0, "The requested number of clusters must be positive"); + Preconditions.checkArgument(maxNumIterations > 0, "The maximum number of iterations must be positive"); + Preconditions.checkArgument(trimFraction > 0, "The trim fraction must be positive"); + Preconditions.checkArgument(testProbability >= 0 && testProbability < 1, "The testProbability must be in [0, 1)"); + Preconditions.checkArgument(numRuns > 0, "There has to be at least one run"); + + this.centroids = searcher; + this.numClusters = numClusters; + this.maxNumIterations = maxNumIterations; + + this.trimFraction = trimFraction; + this.kMeansPlusPlusInit = kMeansPlusPlusInit; + this.correctWeights = correctWeights; + + this.testProbability = testProbability; + this.splitTrainTest = testProbability > 0; + this.numRuns = numRuns; + + this.random = RandomUtils.getRandom(); + } + + public Pair<List<? extends WeightedVector>, List<? extends WeightedVector>> splitTrainTest( + List<? extends WeightedVector> datapoints) { + // If there will be no points assigned to the test set, return now. + if (testProbability == 0) { + return new Pair<List<? extends WeightedVector>, List<? extends WeightedVector>>(datapoints, + Lists.<WeightedVector>newArrayList()); + } + + int numTest = (int) (testProbability * datapoints.size()); + Preconditions.checkArgument(numTest > 0 && numTest < datapoints.size(), + "Must have nonzero number of training and test vectors. 
Asked for %.1f %% of %d vectors for test", + testProbability * 100, datapoints.size()); + + Collections.shuffle(datapoints); + return new Pair<List<? extends WeightedVector>, List<? extends WeightedVector>>( + datapoints.subList(numTest, datapoints.size()), datapoints.subList(0, numTest)); + } + + /** + * Clusters the datapoints in the list doing either random seeding of the centroids or k-means++. + * + * @param datapoints the points to be clustered. + * @return an UpdatableSearcher with the resulting clusters. + */ + public UpdatableSearcher cluster(List<? extends WeightedVector> datapoints) { + Pair<List<? extends WeightedVector>, List<? extends WeightedVector>> trainTestSplit = splitTrainTest(datapoints); + List<Vector> bestCentroids = Lists.newArrayList(); + double cost = Double.POSITIVE_INFINITY; + double bestCost = Double.POSITIVE_INFINITY; + for (int i = 0; i < numRuns; ++i) { + centroids.clear(); + if (kMeansPlusPlusInit) { + // Use k-means++ to set initial centroids. + initializeSeedsKMeansPlusPlus(trainTestSplit.getFirst()); + } else { + // Randomly select the initial centroids. + initializeSeedsRandomly(trainTestSplit.getFirst()); + } + // Do k-means iterations with trimmed mean computation (aka ball k-means). + if (numRuns > 1) { + // If the clustering is successful (there are no zero-weight centroids). + iterativeAssignment(trainTestSplit.getFirst()); + // Compute the cost of the clustering and possibly save the centroids. + cost = ClusteringUtils.totalClusterCost( + splitTrainTest ? datapoints : trainTestSplit.getSecond(), centroids); + if (cost < bestCost) { + bestCost = cost; + bestCentroids.clear(); + Iterables.addAll(bestCentroids, centroids); + } + } else { + // If there is only going to be one run, the cost doesn't need to be computed, so we just return the clustering. 
+ iterativeAssignment(datapoints); + return centroids; + } + } + if (bestCost == Double.POSITIVE_INFINITY) { + throw new RuntimeException("No valid clustering was found"); + } + if (cost != bestCost) { + centroids.clear(); + centroids.addAll(bestCentroids); + } + if (correctWeights) { + for (WeightedVector testDatapoint : trainTestSplit.getSecond()) { + WeightedVector closest = (WeightedVector) centroids.searchFirst(testDatapoint, false).getValue(); + closest.setWeight(closest.getWeight() + testDatapoint.getWeight()); + } + } + return centroids; + } + + /** + * Selects some of the original points randomly with probability proportional to their weights. This is much + * less sophisticated than the kmeans++ approach, however it is faster and coupled with + * + * The side effect of this method is to fill the centroids structure itself. + * + * @param datapoints The datapoints to select from. These datapoints should be WeightedVectors of some kind. + */ + private void initializeSeedsRandomly(List<? extends WeightedVector> datapoints) { + int numDatapoints = datapoints.size(); + double totalWeight = 0; + for (WeightedVector datapoint : datapoints) { + totalWeight += datapoint.getWeight(); + } + Multinomial<Integer> seedSelector = new Multinomial<>(); + for (int i = 0; i < numDatapoints; ++i) { + seedSelector.add(i, datapoints.get(i).getWeight() / totalWeight); + } + for (int i = 0; i < numClusters; ++i) { + int sample = seedSelector.sample(); + seedSelector.delete(sample); + Centroid centroid = new Centroid(datapoints.get(sample)); + centroid.setIndex(i); + centroids.add(centroid); + } + } + + /** + * Selects some of the original points according to the k-means++ algorithm. The basic idea is that + * points are selected with probability proportional to their distance from any selected point. In + * this version, points have weights which multiply their likelihood of being selected. 
This is the + * same as if there were as many copies of the same point as indicated by the weight. + * + * This is pretty expensive, but it vastly improves the quality and convergences of the k-means algorithm. + * The basic idea can be made much faster by only processing a random subset of the original points. + * In the context of streaming k-means, the total number of possible seeds will be about k log n so this + * selection will cost O(k^2 (log n)^2) which isn't much worse than the random sampling idea. At + * n = 10^9, the cost of this initialization will be about 10x worse than a reasonable random sampling + * implementation. + * + * The side effect of this method is to fill the centroids structure itself. + * + * @param datapoints The datapoints to select from. These datapoints should be WeightedVectors of some kind. + */ + private void initializeSeedsKMeansPlusPlus(List<? extends WeightedVector> datapoints) { + Preconditions.checkArgument(datapoints.size() > 1, "Must have at least two datapoints points to cluster " + + "sensibly"); + Preconditions.checkArgument(datapoints.size() >= numClusters, + String.format("Must have more datapoints [%d] than clusters [%d]", datapoints.size(), numClusters)); + // Compute the centroid of all of the datapoints. This is then used to compute the squared radius of the datapoints. + Centroid center = new Centroid(datapoints.iterator().next()); + for (WeightedVector row : Iterables.skip(datapoints, 1)) { + center.update(row); + } + + // Given the centroid, we can compute \Delta_1^2(X), the total squared distance for the datapoints + // this accelerates seed selection. 
+ double deltaX = 0; + DistanceMeasure distanceMeasure = centroids.getDistanceMeasure(); + for (WeightedVector row : datapoints) { + deltaX += distanceMeasure.distance(row, center); + } + + // Find the first seed c_1 (and conceptually the second, c_2) as might be done in the 2-means clustering so that + // the probability of selecting c_1 and c_2 is proportional to || c_1 - c_2 ||^2. This is done + // by first selecting c_1 with probability: + // + // p(c_1) = sum_{c_1} || c_1 - c_2 ||^2 \over sum_{c_1, c_2} || c_1 - c_2 ||^2 + // + // This can be simplified to: + // + // p(c_1) = \Delta_1^2(X) + n || c_1 - c ||^2 / (2 n \Delta_1^2(X)) + // + // where c = \sum x / n and \Delta_1^2(X) = sum || x - c ||^2 + // + // All subsequent seeds c_i (including c_2) can then be selected from the remaining points with probability + // proportional to Pr(c_i == x_j) = min_{m < i} || c_m - x_j ||^2. + + // Multinomial distribution of vector indices for the selection seeds. These correspond to + // the indices of the vectors in the original datapoints list. + Multinomial<Integer> seedSelector = new Multinomial<>(); + for (int i = 0; i < datapoints.size(); ++i) { + double selectionProbability = + deltaX + datapoints.size() * distanceMeasure.distance(datapoints.get(i), center); + seedSelector.add(i, selectionProbability); + } + + int selected = random.nextInt(datapoints.size()); + Centroid c_1 = new Centroid(datapoints.get(selected).clone()); + c_1.setIndex(0); + // Construct a set of weighted things which can be used for random selection. 
Initial weights are + // set to the squared distance from c_1 + for (int i = 0; i < datapoints.size(); ++i) { + WeightedVector row = datapoints.get(i); + double w = distanceMeasure.distance(c_1, row) * 2 * Math.log(1 + row.getWeight()); + seedSelector.set(i, w); + } + + // From here, seeds are selected with probability proportional to: + // + // r_i = min_{c_j} || x_i - c_j ||^2 + // + // when we only have c_1, we have already set these distances and as we select each new + // seed, we update the minimum distances. + centroids.add(c_1); + int clusterIndex = 1; + while (centroids.size() < numClusters) { + // Select according to weights. + int seedIndex = seedSelector.sample(); + Centroid nextSeed = new Centroid(datapoints.get(seedIndex)); + nextSeed.setIndex(clusterIndex++); + centroids.add(nextSeed); + // Don't select this one again. + seedSelector.delete(seedIndex); + // Re-weight everything according to the minimum distance to a seed. + for (int currSeedIndex : seedSelector) { + WeightedVector curr = datapoints.get(currSeedIndex); + double newWeight = nextSeed.getWeight() * distanceMeasure.distance(nextSeed, curr); + if (newWeight < seedSelector.getWeight(currSeedIndex)) { + seedSelector.set(currSeedIndex, newWeight); + } + } + } + } + + /** + * Examines the datapoints and updates cluster centers to be the centroid of the nearest datapoints points. To + * compute a new center for cluster c_i, we average all points that are closer than d_i * trimFraction + * where d_i is + * + * d_i = min_j \sqrt ||c_j - c_i||^2 + * + * By ignoring distant points, the centroids converge more quickly to a good approximation of the + * optimal k-means solution (given good starting points). + * + * @param datapoints the points to cluster. + */ + private void iterativeAssignment(List<? 
extends WeightedVector> datapoints) { + DistanceMeasure distanceMeasure = centroids.getDistanceMeasure(); + // closestClusterDistances.get(i) is the distance from the i'th cluster to its closest + // neighboring cluster. + List<Double> closestClusterDistances = Lists.newArrayListWithExpectedSize(numClusters); + // clusterAssignments[i] == j means that the i'th point is assigned to the j'th cluster. When + // these don't change, we are done. + // Each point is assigned to the invalid "-1" cluster initially. + List<Integer> clusterAssignments = Lists.newArrayList(Collections.nCopies(datapoints.size(), -1)); + + boolean changed = true; + for (int i = 0; changed && i < maxNumIterations; i++) { + changed = false; + // We compute what the distance between each cluster and its closest neighbor is to set a + // proportional distance threshold for points that should be involved in calculating the + // centroid. + closestClusterDistances.clear(); + for (Vector center : centroids) { + // If a centroid has no points assigned to it, the clustering failed. + Vector closestOtherCluster = centroids.searchFirst(center, true).getValue(); + closestClusterDistances.add(distanceMeasure.distance(center, closestOtherCluster)); + } + + // Copies the current cluster centroids to newClusters and sets their weights to 0. This is + // so we calculate the new centroids as we go through the datapoints. + List<Centroid> newCentroids = Lists.newArrayList(); + for (Vector centroid : centroids) { + // need a deep copy because we will mutate these values + Centroid newCentroid = (Centroid)centroid.clone(); + newCentroid.setWeight(0); + newCentroids.add(newCentroid); + } + + // Pass over the datapoints computing new centroids. + for (int j = 0; j < datapoints.size(); ++j) { + WeightedVector datapoint = datapoints.get(j); + // Get the closest cluster this point belongs to. 
+ WeightedThing<Vector> closestPair = centroids.searchFirst(datapoint, false); + int closestIndex = ((WeightedVector) closestPair.getValue()).getIndex(); + double closestDistance = closestPair.getWeight(); + // Update its cluster assignment if necessary. + if (closestIndex != clusterAssignments.get(j)) { + changed = true; + clusterAssignments.set(j, closestIndex); + } + // Only update if the datapoints point is near enough. What this means is that the weight + // of outliers is NOT taken into account and the final weights of the centroids will + // reflect this (it will be less or equal to the initial sum of the weights). + if (closestDistance < trimFraction * closestClusterDistances.get(closestIndex)) { + newCentroids.get(closestIndex).update(datapoint); + } + } + // Add the new centers back into searcher. + centroids.clear(); + centroids.addAll(newCentroids); + } + + if (correctWeights) { + for (Vector v : centroids) { + ((Centroid)v).setWeight(0); + } + for (WeightedVector datapoint : datapoints) { + Centroid closestCentroid = (Centroid) centroids.searchFirst(datapoint, false).getValue(); + closestCentroid.setWeight(closestCentroid.getWeight() + datapoint.getWeight()); + } + } + } + + @Override + public Iterator<Centroid> iterator() { + return Iterators.transform(centroids.iterator(), new Function<Vector, Centroid>() { + @Override + public Centroid apply(Vector input) { + Preconditions.checkArgument(input instanceof Centroid, "Non-centroid in centroids " + + "searcher"); + //noinspection ConstantConditions + return (Centroid)input; + } + }); + } +}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/streaming/cluster/StreamingKMeans.java ---------------------------------------------------------------------- diff --git a/mr/src/main/java/org/apache/mahout/clustering/streaming/cluster/StreamingKMeans.java b/mr/src/main/java/org/apache/mahout/clustering/streaming/cluster/StreamingKMeans.java new file mode 100644 index 0000000..0e3f068 --- /dev/null +++ b/mr/src/main/java/org/apache/mahout/clustering/streaming/cluster/StreamingKMeans.java @@ -0,0 +1,368 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.clustering.streaming.cluster; + +import java.util.Collections; +import java.util.Iterator; +import java.util.List; +import java.util.Random; + +import com.google.common.base.Function; +import com.google.common.collect.Iterables; +import com.google.common.collect.Iterators; +import com.google.common.collect.Lists; +import org.apache.mahout.common.RandomUtils; +import org.apache.mahout.common.distance.DistanceMeasure; +import org.apache.mahout.math.Centroid; +import org.apache.mahout.math.Matrix; +import org.apache.mahout.math.MatrixSlice; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.jet.math.Constants; +import org.apache.mahout.math.neighborhood.UpdatableSearcher; +import org.apache.mahout.math.random.WeightedThing; + +/** + * Implements a streaming k-means algorithm for weighted vectors. + * The goal clustering points one at a time, especially useful for MapReduce mappers that get inputs one at a time. + * + * A rough description of the algorithm: + * Suppose there are l clusters at one point and a new point p is added. + * The new point can either be added to one of the existing l clusters or become a new cluster. To decide: + * - let c be the closest cluster to point p; + * - let d be the distance between c and p; + * - if d > distanceCutoff, create a new cluster from p (p is too far away from the clusters to be part of them; + * distanceCutoff represents the largest distance from a point its assigned cluster's centroid); + * - else (d <= distanceCutoff), create a new cluster with probability d / distanceCutoff (the probability of creating + * a new cluster increases as d increases). + * There will be either l points or l + 1 points after processing a new point. + * + * As the number of clusters increases, it will go over the numClusters limit (numClusters represents a recommendation + * for the number of clusters that there should be at the end). 
To decrease the number of clusters the existing clusters + * are treated as data points and are re-clustered (collapsed). This tends to make the number of clusters go down. + * If the number of clusters is still too high, distanceCutoff is increased. + * + * For more details, see: + * - "Streaming k-means approximation" by N. Ailon, R. Jaiswal, C. Monteleoni + * http://books.nips.cc/papers/files/nips22/NIPS2009_1085.pdf + * - "Fast and Accurate k-means for Large Datasets" by M. Shindler, A. Wong, A. Meyerson, + * http://books.nips.cc/papers/files/nips24/NIPS2011_1271.pdf + */ +public class StreamingKMeans implements Iterable<Centroid> { + /** + * The searcher containing the centroids that resulted from the clustering of points until now. When adding a new + * point we either assign it to one of the existing clusters in this searcher or create a new centroid for it. + */ + private final UpdatableSearcher centroids; + + /** + * The estimated number of clusters to cluster the data in. If the actual number of clusters increases beyond this + * limit, the clusters will be "collapsed" (re-clustered, by treating them as data points). This doesn't happen + * recursively and a collapse might not necessarily make the number of actual clusters drop to less than this limit. + * + * If the goal is clustering a large data set into k clusters, numClusters SHOULD NOT BE SET to k. StreamingKMeans is + * useful to reduce the size of the data set by the mappers so that it can fit into memory in one reducer that runs + * BallKMeans. + * + * It is NOT MEANT to cluster the data into k clusters in one pass because it can't guarantee that there will in fact + * be k clusters in total. This is because of the dynamic nature of numClusters over the course of the runtime. + * To get an exact number of clusters, another clustering algorithm needs to be applied to the results. + */ + private int numClusters; + + /** + * The number of data points seen so far. 
This is important for re-estimating numClusters when deciding to collapse + * the existing clusters. + */ + private int numProcessedDatapoints = 0; + + /** + * This is the current value of the distance cutoff. Points which are much closer than this to a centroid will stick + * to it almost certainly. Points further than this to any centroid will form a new cluster. + * + * This increases (is multiplied by beta) when a cluster collapse did not make the number of clusters drop to below + * numClusters (it effectively increases the tolerance for cluster compactness discouraging the creation of new + * clusters). Since a collapse only happens when centroids.size() > clusterOvershoot * numClusters, the cutoff + * increases when the collapse didn't at least remove the slack in the number of clusters. + */ + private double distanceCutoff; + + /** + * Parameter that controls the growth of the distanceCutoff. After n increases of the + * distanceCutoff starting at d_0, the final value is d_0 * beta^n (distance cutoffs increase following a geometric + * progression with ratio beta). + */ + private final double beta; + + /** + * Multiplying clusterLogFactor with numProcessedDatapoints gets an estimate of the suggested + * number of clusters. This mirrors the recommended number of clusters for n points where there should be k actual + * clusters, k * log n. In the case of our estimate we use clusterLogFactor * log(numProcessedDataPoints). + * + * It is important to note that numClusters is NOT k. It is an estimate of k * log n. + */ + private final double clusterLogFactor; + + /** + * Centroids are collapsed when the number of clusters becomes greater than clusterOvershoot * numClusters. This + * effectively means having a slack in numClusters so that the actual number of centroids, centroids.size() tracks + * numClusters approximately. 
The idea is that the actual number of clusters should be at least numClusters but not
   * much more (so that we don't end up having 1 cluster / point).
   */
  private final double clusterOvershoot;

  /**
   * Random object to sample values from.
   */
  private final Random random = RandomUtils.getRandom();

  /**
   * Calls StreamingKMeans(searcher, numClusters, 1.0 / numClusters, 1.3, 20, 2): the initial distanceCutoff is
   * 1 / numClusters, beta is 1.3, clusterLogFactor is 20 and clusterOvershoot is 2.
   * @see StreamingKMeans#StreamingKMeans(org.apache.mahout.math.neighborhood.UpdatableSearcher, int,
   * double, double, double, double)
   */
  public StreamingKMeans(UpdatableSearcher searcher, int numClusters) {
    this(searcher, numClusters, 1.0 / numClusters, 1.3, 20, 2);
  }

  /**
   * Calls StreamingKMeans(searcher, numClusters, distanceCutoff, 1.3, 20, 2): beta is 1.3, clusterLogFactor is 20
   * and clusterOvershoot is 2.
   * @see StreamingKMeans#StreamingKMeans(org.apache.mahout.math.neighborhood.UpdatableSearcher, int,
   * double, double, double, double)
   */
  public StreamingKMeans(UpdatableSearcher searcher, int numClusters, double distanceCutoff) {
    this(searcher, numClusters, distanceCutoff, 1.3, 20, 2);
  }

  /**
   * Creates a new StreamingKMeans class given a searcher and the number of clusters to generate.
   *
   * @param searcher A Searcher that is used for performing nearest neighbor search. It MUST BE
   *                 EMPTY initially because it will be used to keep track of the cluster
   *                 centroids.
   * @param numClusters An estimated number of clusters to generate for the data points.
   *                    This can be adjusted at runtime, but the actual number will depend on the data.
   * @param distanceCutoff The initial distance cutoff representing the value of the
   *                       distance between a point and its closest centroid after which
   *                       the new point will definitely be assigned to a new cluster.
   * @param beta Ratio of geometric progression to use when increasing distanceCutoff. After n increases, distanceCutoff
   *             becomes distanceCutoff * beta^n. A smaller value increases the distanceCutoff less aggressively.
   * @param clusterLogFactor Value multiplied with the number of points counted so far estimating the number of clusters
   *                         to aim for. If the final number of clusters is known and this clustering is only for a
   *                         sketch of the data, this can be the final number of clusters, k.
   * @param clusterOvershoot Multiplicative slack factor for slowing down the collapse of the clusters.
   */
  public StreamingKMeans(UpdatableSearcher searcher, int numClusters,
                         double distanceCutoff, double beta, double clusterLogFactor, double clusterOvershoot) {
    this.centroids = searcher;
    this.numClusters = numClusters;
    this.distanceCutoff = distanceCutoff;
    this.beta = beta;
    this.clusterLogFactor = clusterLogFactor;
    this.clusterOvershoot = clusterOvershoot;
  }

  /**
   * @return an Iterator to the Centroids contained in this clusterer.
   */
  @Override
  public Iterator<Centroid> iterator() {
    return Iterators.transform(centroids.iterator(), new Function<Vector, Centroid>() {
      @Override
      public Centroid apply(Vector input) {
        return (Centroid) input;
      }
    });
  }

  /**
   * Cluster the rows of a matrix, treating them as Centroids with weight 1.
   * @param data matrix whose rows are to be clustered.
   * @return the UpdatableSearcher containing the resulting centroids.
   */
  public UpdatableSearcher cluster(Matrix data) {
    return cluster(Iterables.transform(data, new Function<MatrixSlice, Centroid>() {
      @Override
      public Centroid apply(MatrixSlice input) {
        // The key in a Centroid is actually the MatrixSlice's index.
        return Centroid.create(input.index(), input.vector());
      }
    }));
  }

  /**
   * Cluster the data points in an Iterable<Centroid>.
   * @param datapoints Iterable whose elements are to be clustered.
   * @return the UpdatableSearcher containing the resulting centroids.
   */
  public UpdatableSearcher cluster(Iterable<Centroid> datapoints) {
    return clusterInternal(datapoints, false);
  }

  /**
   * Cluster one data point.
   * @param datapoint to be clustered.
   * @return the UpdatableSearcher containing the resulting centroids.
   */
  public UpdatableSearcher cluster(final Centroid datapoint) {
    return cluster(new Iterable<Centroid>() {
      @Override
      public Iterator<Centroid> iterator() {
        // Single-element iterator over the given point.
        return new Iterator<Centroid>() {
          private boolean accessed = false;

          @Override
          public boolean hasNext() {
            return !accessed;
          }

          @Override
          public Centroid next() {
            // Iterator contract: calling next() on an exhausted iterator must throw.
            if (accessed) {
              throw new java.util.NoSuchElementException();
            }
            accessed = true;
            return datapoint;
          }

          @Override
          public void remove() {
            throw new UnsupportedOperationException();
          }
        };
      }
    });
  }

  /**
   * @return the number of clusters computed from the points until now.
   */
  public int getNumClusters() {
    return centroids.size();
  }

  /**
   * Internal clustering method that gets called from the other wrappers.
   * @param datapoints Iterable of data points to be clustered.
   * @param collapseClusters whether this is an "inner" clustering and the datapoints are the previously computed
   *                         centroids. Some logic is different to ensure counters are consistent but it behaves
   *                         nearly the same.
   * @return the UpdatableSearcher containing the resulting centroids.
   * @throws IllegalStateException if a centroid found by the searcher cannot be removed from it again.
   */
  private UpdatableSearcher clusterInternal(Iterable<Centroid> datapoints, boolean collapseClusters) {
    Iterator<Centroid> datapointsIterator = datapoints.iterator();
    if (!datapointsIterator.hasNext()) {
      return centroids;
    }

    int oldNumProcessedDataPoints = numProcessedDatapoints;
    // We clear the centroids we have in case of cluster collapse, the old clusters are the
    // datapoints but we need to re-cluster them.
    if (collapseClusters) {
      centroids.clear();
      numProcessedDatapoints = 0;
    }

    if (centroids.size() == 0) {
      // Assign the first datapoint to the first cluster.
      // Adding a vector to a searcher would normally just reference the copy,
      // but we could potentially mutate it and so we need to make a clone.
      centroids.add(datapointsIterator.next().clone());
      ++numProcessedDatapoints;
    }

    // To cluster, we scan the data and either add each point to the nearest group or create a new group.
    // When we get too many groups, we need to increase the threshold and rescan our current groups.
    while (datapointsIterator.hasNext()) {
      Centroid row = datapointsIterator.next();
      // Get the closest vector and its weight as a WeightedThing<Vector>.
      // The weight of the WeightedThing is the distance to the query and the value is a
      // reference to one of the vectors we added to the searcher previously.
      WeightedThing<Vector> closestPair = centroids.searchFirst(row, false);

      // We get a uniformly distributed random number between 0 and 1 and compare it with the
      // row's weight times the distance to the closest cluster divided by the distanceCutoff.
      // If that ratio exceeds 1 (e.g. a unit-weight point further than distanceCutoff), a new
      // cluster is always created; otherwise a new cluster is created with probability
      // proportional to the weighted distance to the closest cluster.
      double sample = random.nextDouble();
      if (sample < row.getWeight() * closestPair.getWeight() / distanceCutoff) {
        // Add new centroid, note that the vector is copied because we may mutate it later.
        centroids.add(row.clone());
      } else {
        // Merge the new point with the existing centroid. This will update the centroid's actual
        // position.
        // Every vector inserted into the centroids searcher by this class is a Centroid,
        // so the cast will always succeed.
        Centroid centroid = (Centroid) closestPair.getValue();

        // We will update the centroid by removing it from the searcher and reinserting it to
        // ensure consistency.
        if (!centroids.remove(centroid, Constants.EPSILON)) {
          throw new IllegalStateException("Unable to remove centroid");
        }
        centroid.update(row);
        centroids.add(centroid);
      }
      ++numProcessedDatapoints;

      if (!collapseClusters && centroids.size() > clusterOvershoot * numClusters) {
        numClusters = (int) Math.max(numClusters, clusterLogFactor * Math.log(numProcessedDatapoints));

        List<Centroid> shuffled = Lists.newArrayList();
        for (Vector vector : centroids) {
          shuffled.add((Centroid) vector);
        }
        Collections.shuffle(shuffled);
        // Re-cluster using the shuffled centroids as data points. The centroids member variable
        // is modified directly.
        clusterInternal(shuffled, true);

        if (centroids.size() > numClusters) {
          distanceCutoff *= beta;
        }
      }
    }

    // An inner (collapse) pass re-counts existing centroids; restore the outer pass's point count.
    if (collapseClusters) {
      numProcessedDatapoints = oldNumProcessedDataPoints;
    }
    return centroids;
  }

  /**
   * Assigns consecutive indices 0, 1, 2, ... to the current centroids in iteration order.
   */
  public void reindexCentroids() {
    int numCentroids = 0;
    for (Centroid centroid : this) {
      centroid.setIndex(numCentroids++);
    }
  }

  /**
   * @return the distanceCutoff (an upper bound for the maximum distance within a cluster).
   */
  public double getDistanceCutoff() {
    return distanceCutoff;
  }

  public void setDistanceCutoff(double distanceCutoff) {
    this.distanceCutoff = distanceCutoff;
  }

  public DistanceMeasure getDistanceMeasure() {
    return centroids.getDistanceMeasure();
  }
}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/CentroidWritable.java ---------------------------------------------------------------------- diff --git a/mr/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/CentroidWritable.java b/mr/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/CentroidWritable.java new file mode 100644 index 0000000..a41940b --- /dev/null +++ b/mr/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/CentroidWritable.java @@ -0,0 +1,88 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.mahout.clustering.streaming.mapreduce; + +import org.apache.hadoop.io.Writable; +import org.apache.mahout.math.Centroid; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.VectorWritable; + +import java.io.DataInput; +import java.io.DataOutput; +import java.io.IOException; + +public class CentroidWritable implements Writable { + private Centroid centroid = null; + + public CentroidWritable() {} + + public CentroidWritable(Centroid centroid) { + this.centroid = centroid; + } + + public Centroid getCentroid() { + return centroid; + } + + @Override + public void write(DataOutput dataOutput) throws IOException { + dataOutput.writeInt(centroid.getIndex()); + dataOutput.writeDouble(centroid.getWeight()); + VectorWritable.writeVector(dataOutput, centroid.getVector()); + } + + @Override + public void readFields(DataInput dataInput) throws IOException { + if (centroid == null) { + centroid = read(dataInput); + return; + } + centroid.setIndex(dataInput.readInt()); + centroid.setWeight(dataInput.readDouble()); + centroid.assign(VectorWritable.readVector(dataInput)); + } + + public static Centroid read(DataInput dataInput) throws IOException { + int index = dataInput.readInt(); + double weight = dataInput.readDouble(); + Vector v = VectorWritable.readVector(dataInput); + return new Centroid(index, v, weight); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof CentroidWritable)) { + return false; + } + CentroidWritable writable = (CentroidWritable) o; + return centroid.equals(writable.centroid); + } + + @Override + public int hashCode() { + return centroid.hashCode(); + } + + @Override + public String toString() { + return centroid.toString(); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansDriver.java ---------------------------------------------------------------------- 
diff --git a/mr/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansDriver.java b/mr/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansDriver.java new file mode 100644 index 0000000..73776b9 --- /dev/null +++ b/mr/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansDriver.java @@ -0,0 +1,493 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.clustering.streaming.mapreduce; + +import java.io.IOException; +import java.util.List; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; + +import com.google.common.base.Preconditions; +import com.google.common.collect.Lists; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FileStatus; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.SequenceFile; +import org.apache.hadoop.mapreduce.Job; +import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat; +import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat; +import org.apache.hadoop.util.ToolRunner; +import org.apache.mahout.common.AbstractJob; +import org.apache.mahout.common.HadoopUtil; +import org.apache.mahout.common.commandline.DefaultOptionCreator; +import org.apache.mahout.common.iterator.sequencefile.PathFilters; +import org.apache.mahout.math.Centroid; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.neighborhood.BruteSearch; +import org.apache.mahout.math.neighborhood.ProjectionSearch; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Classifies the vectors into different clusters found by the clustering + * algorithm. + */ +public final class StreamingKMeansDriver extends AbstractJob { + /** + * Streaming KMeans options + */ + /** + * The number of cluster that Mappers will use should be \(O(k log n)\) where k is the number of clusters + * to get at the end and n is the number of points to cluster. This doesn't need to be exact. + * It will be adjusted at runtime. 
+ */ + public static final String ESTIMATED_NUM_MAP_CLUSTERS = "estimatedNumMapClusters"; + /** + * The initial estimated distance cutoff between two points for forming new clusters. + * @see org.apache.mahout.clustering.streaming.cluster.StreamingKMeans + * Defaults to 10e-6. + */ + public static final String ESTIMATED_DISTANCE_CUTOFF = "estimatedDistanceCutoff"; + + /** + * Ball KMeans options + */ + /** + * After mapping finishes, we get an intermediate set of vectors that represent approximate + * clusterings of the data from each Mapper. These can be clustered by the Reducer using + * BallKMeans in memory. This variable is the maximum number of iterations in the final + * BallKMeans algorithm. + * Defaults to 10. + */ + public static final String MAX_NUM_ITERATIONS = "maxNumIterations"; + /** + * The "ball" aspect of ball k-means means that only the closest points to the centroid will actually be used + * for updating. The fraction of the points to be used is those points whose distance to the center is within + * trimFraction * distance to the closest other center. + * Defaults to 0.9. + */ + public static final String TRIM_FRACTION = "trimFraction"; + /** + * Whether to use k-means++ initialization or random initialization of the seed centroids. + * Essentially, k-means++ provides better clusters, but takes longer, whereas random initialization takes less + * time, but produces worse clusters, and tends to fail more often and needs multiple runs to compare to + * k-means++. If set, uses randomInit. + * @see org.apache.mahout.clustering.streaming.cluster.BallKMeans + */ + public static final String RANDOM_INIT = "randomInit"; + /** + * Whether to correct the weights of the centroids after the clustering is done. The weights end up being wrong + * because of the trimFraction and possible train/test splits. In some cases, especially in a pipeline, having + * an accurate count of the weights is useful. If set, ignores the final weights. 
+ */ + public static final String IGNORE_WEIGHTS = "ignoreWeights"; + /** + * The percentage of points that go into the "test" set when evaluating BallKMeans runs in the reducer. + */ + public static final String TEST_PROBABILITY = "testProbability"; + /** + * The percentage of points that go into the "training" set when evaluating BallKMeans runs in the reducer. + */ + public static final String NUM_BALLKMEANS_RUNS = "numBallKMeansRuns"; + + /** + Searcher options + */ + /** + * The Searcher class when performing nearest neighbor search in StreamingKMeans. + * Defaults to ProjectionSearch. + */ + public static final String SEARCHER_CLASS_OPTION = "searcherClass"; + /** + * The number of projections to use when using a projection searcher like ProjectionSearch or + * FastProjectionSearch. Projection searches work by projection the all the vectors on to a set of + * basis vectors and searching for the projected query in that totally ordered set. This + * however can produce false positives (vectors that are closer when projected than they would + * actually be. + * So, there must be more than one projection vectors in the basis. This variable is the number + * of vectors in a basis. + * Defaults to 3 + */ + public static final String NUM_PROJECTIONS_OPTION = "numProjections"; + /** + * When using approximate searches (anything that's not BruteSearch), + * more than just the seemingly closest element must be considered. This variable has different + * meanings depending on the actual Searcher class used but is a measure of how many candidates + * will be considered. + * See the ProjectionSearch, FastProjectionSearch, LocalitySensitiveHashSearch classes for more + * details. + * Defaults to 2. + */ + public static final String SEARCH_SIZE_OPTION = "searchSize"; + + /** + * Whether to run another pass of StreamingKMeans on the reducer's points before BallKMeans. 
On some data sets + * with a large number of mappers, the intermediate number of clusters passed to the reducer is too large to + * fit into memory directly, hence the option to collapse the clusters further with StreamingKMeans. + */ + public static final String REDUCE_STREAMING_KMEANS = "reduceStreamingKMeans"; + + private static final Logger log = LoggerFactory.getLogger(StreamingKMeansDriver.class); + + public static final float INVALID_DISTANCE_CUTOFF = -1; + + @Override + public int run(String[] args) throws Exception { + // Standard options for any Mahout job. + addInputOption(); + addOutputOption(); + addOption(DefaultOptionCreator.overwriteOption().create()); + + // The number of clusters to create for the data. + addOption(DefaultOptionCreator.numClustersOption().withDescription( + "The k in k-Means. Approximately this many clusters will be generated.").create()); + + // StreamingKMeans (mapper) options + // There will be k final clusters, but in the Map phase to get a good approximation of the data, O(k log n) + // clusters are needed. Since n is the number of data points and not knowable until reading all the vectors, + // provide a decent estimate. + addOption(ESTIMATED_NUM_MAP_CLUSTERS, "km", "The estimated number of clusters to use for the " + + "Map phase of the job when running StreamingKMeans. This should be around k * log(n), " + + "where k is the final number of clusters and n is the total number of data points to " + + "cluster.", String.valueOf(1)); + + addOption(ESTIMATED_DISTANCE_CUTOFF, "e", "The initial estimated distance cutoff between two " + + "points for forming new clusters. If no value is given, it's estimated from the data set", + String.valueOf(INVALID_DISTANCE_CUTOFF)); + + // BallKMeans (reducer) options + addOption(MAX_NUM_ITERATIONS, "mi", "The maximum number of iterations to run for the " + + "BallKMeans algorithm used by the reducer. 
If no value is given, defaults to 10.", String.valueOf(10)); + + addOption(TRIM_FRACTION, "tf", "The 'ball' aspect of ball k-means means that only the closest points " + + "to the centroid will actually be used for updating. The fraction of the points to be used is those " + + "points whose distance to the center is within trimFraction * distance to the closest other center. " + + "If no value is given, defaults to 0.9.", String.valueOf(0.9)); + + addFlag(RANDOM_INIT, "ri", "Whether to use k-means++ initialization or random initialization " + + "of the seed centroids. Essentially, k-means++ provides better clusters, but takes longer, whereas random " + + "initialization takes less time, but produces worse clusters, and tends to fail more often and needs " + + "multiple runs to compare to k-means++. If set, uses the random initialization."); + + addFlag(IGNORE_WEIGHTS, "iw", "Whether to correct the weights of the centroids after the clustering is done. " + + "The weights end up being wrong because of the trimFraction and possible train/test splits. In some cases, " + + "especially in a pipeline, having an accurate count of the weights is useful. If set, ignores the final " + + "weights"); + + addOption(TEST_PROBABILITY, "testp", "A double value between 0 and 1 that represents the percentage of " + + "points to be used for 'testing' different clustering runs in the final BallKMeans " + + "step. If no value is given, defaults to 0.1", String.valueOf(0.1)); + + addOption(NUM_BALLKMEANS_RUNS, "nbkm", "Number of BallKMeans runs to use at the end to try to cluster the " + + "points. If no value is given, defaults to 4", String.valueOf(4)); + + // Nearest neighbor search options + // The distance measure used for computing the distance between two points. Generally, the + // SquaredEuclideanDistance is used for clustering problems (it's equivalent to CosineDistance for normalized + // vectors). + // WARNING! 
You can use any metric but most of the literature is for the squared euclidean distance. + addOption(DefaultOptionCreator.distanceMeasureOption().create()); + + // The default searcher should be something more efficient that BruteSearch (ProjectionSearch, ...). See + // o.a.m.math.neighborhood.* + addOption(SEARCHER_CLASS_OPTION, "sc", "The type of searcher to be used when performing nearest " + + "neighbor searches. Defaults to ProjectionSearch.", ProjectionSearch.class.getCanonicalName()); + + // In the original paper, the authors used 1 projection vector. + addOption(NUM_PROJECTIONS_OPTION, "np", "The number of projections considered in estimating the " + + "distances between vectors. Only used when the distance measure requested is either " + + "ProjectionSearch or FastProjectionSearch. If no value is given, defaults to 3.", String.valueOf(3)); + + addOption(SEARCH_SIZE_OPTION, "s", "In more efficient searches (non BruteSearch), " + + "not all distances are calculated for determining the nearest neighbors. The number of " + + "elements whose distances from the query vector is actually computer is proportional to " + + "searchSize. 
If no value is given, defaults to 1.", String.valueOf(2)); + + addFlag(REDUCE_STREAMING_KMEANS, "rskm", "There might be too many intermediate clusters from the mapper " + + "to fit into memory, so the reducer can run another pass of StreamingKMeans to collapse them down to a " + + "fewer clusters"); + + addOption(DefaultOptionCreator.methodOption().create()); + + if (parseArguments(args) == null) { + return -1; + } + Path output = getOutputPath(); + if (hasOption(DefaultOptionCreator.OVERWRITE_OPTION)) { + HadoopUtil.delete(getConf(), output); + } + configureOptionsForWorkers(); + run(getConf(), getInputPath(), output); + return 0; + } + + private void configureOptionsForWorkers() throws ClassNotFoundException { + log.info("Starting to configure options for workers"); + + String method = getOption(DefaultOptionCreator.METHOD_OPTION); + + int numClusters = Integer.parseInt(getOption(DefaultOptionCreator.NUM_CLUSTERS_OPTION)); + + // StreamingKMeans + int estimatedNumMapClusters = Integer.parseInt(getOption(ESTIMATED_NUM_MAP_CLUSTERS)); + float estimatedDistanceCutoff = Float.parseFloat(getOption(ESTIMATED_DISTANCE_CUTOFF)); + + // BallKMeans + int maxNumIterations = Integer.parseInt(getOption(MAX_NUM_ITERATIONS)); + float trimFraction = Float.parseFloat(getOption(TRIM_FRACTION)); + boolean randomInit = hasOption(RANDOM_INIT); + boolean ignoreWeights = hasOption(IGNORE_WEIGHTS); + float testProbability = Float.parseFloat(getOption(TEST_PROBABILITY)); + int numBallKMeansRuns = Integer.parseInt(getOption(NUM_BALLKMEANS_RUNS)); + + // Nearest neighbor search + String measureClass = getOption(DefaultOptionCreator.DISTANCE_MEASURE_OPTION); + String searcherClass = getOption(SEARCHER_CLASS_OPTION); + + // Get more parameters depending on the kind of search class we're working with. BruteSearch + // doesn't need anything else. + // LocalitySensitiveHashSearch and ProjectionSearches need searchSize. + // ProjectionSearches also need the number of projections. 
+ boolean getSearchSize = false; + boolean getNumProjections = false; + if (!searcherClass.equals(BruteSearch.class.getName())) { + getSearchSize = true; + getNumProjections = true; + } + + // The search size to use. This is quite fuzzy and might end up not being configurable at all. + int searchSize = 0; + if (getSearchSize) { + searchSize = Integer.parseInt(getOption(SEARCH_SIZE_OPTION)); + } + + // The number of projections to use. This is only useful in projection searches which + // project the vectors on multiple basis vectors to get distance estimates that are faster to + // calculate. + int numProjections = 0; + if (getNumProjections) { + numProjections = Integer.parseInt(getOption(NUM_PROJECTIONS_OPTION)); + } + + boolean reduceStreamingKMeans = hasOption(REDUCE_STREAMING_KMEANS); + + configureOptionsForWorkers(getConf(), numClusters, + /* StreamingKMeans */ + estimatedNumMapClusters, estimatedDistanceCutoff, + /* BallKMeans */ + maxNumIterations, trimFraction, randomInit, ignoreWeights, testProbability, numBallKMeansRuns, + /* Searcher */ + measureClass, searcherClass, searchSize, numProjections, + method, + reduceStreamingKMeans); + } + + /** + * Checks the parameters for a StreamingKMeans job and prepares a Configuration with them. 
+ * + * @param conf the Configuration to populate + * @param numClusters k, the number of clusters at the end + * @param estimatedNumMapClusters O(k log n), the number of clusters requested from each mapper + * @param estimatedDistanceCutoff an estimate of the minimum distance that separates two clusters (can be smaller and + * will be increased dynamically) + * @param maxNumIterations the maximum number of iterations of BallKMeans + * @param trimFraction the fraction of the points to be considered in updating a ball k-means + * @param randomInit whether to initialize the ball k-means seeds randomly + * @param ignoreWeights whether to ignore the invalid final ball k-means weights + * @param testProbability the percentage of vectors assigned to the test set for selecting the best final centers + * @param numBallKMeansRuns the number of BallKMeans runs in the reducer that determine the centroids to return + * (clusters are computed for the training set and the error is computed on the test set) + * @param measureClass string, name of the distance measure class; theory works for Euclidean-like distances + * @param searcherClass string, name of the searcher that will be used for nearest neighbor search + * @param searchSize the number of closest neighbors to look at for selecting the closest one in approximate nearest + * neighbor searches + * @param numProjections the number of projected vectors to use for faster searching (only useful for ProjectionSearch + * or FastProjectionSearch); @see org.apache.mahout.math.neighborhood.ProjectionSearch + */ + public static void configureOptionsForWorkers(Configuration conf, + int numClusters, + /* StreamingKMeans */ + int estimatedNumMapClusters, float estimatedDistanceCutoff, + /* BallKMeans */ + int maxNumIterations, float trimFraction, boolean randomInit, + boolean ignoreWeights, float testProbability, int numBallKMeansRuns, + /* Searcher */ + String measureClass, String searcherClass, + int searchSize, int numProjections, + 
String method, + boolean reduceStreamingKMeans) throws ClassNotFoundException { + // Checking preconditions for the parameters. + Preconditions.checkArgument(numClusters > 0, + "Invalid number of clusters requested: " + numClusters + ". Must be: numClusters > 0!"); + + // StreamingKMeans + Preconditions.checkArgument(estimatedNumMapClusters > numClusters, "Invalid number of estimated map " + + "clusters; There must be more than the final number of clusters (k log n vs k)"); + Preconditions.checkArgument(estimatedDistanceCutoff == INVALID_DISTANCE_CUTOFF || estimatedDistanceCutoff > 0, + "estimatedDistanceCutoff must be equal to -1 or must be greater then 0!"); + + // BallKMeans + Preconditions.checkArgument(maxNumIterations > 0, "Must have at least one BallKMeans iteration"); + Preconditions.checkArgument(trimFraction > 0, "trimFraction must be positive"); + Preconditions.checkArgument(testProbability >= 0 && testProbability < 1, "test probability is not in the " + + "interval [0, 1)"); + Preconditions.checkArgument(numBallKMeansRuns > 0, "numBallKMeans cannot be negative"); + + // Searcher + if (!searcherClass.contains("Brute")) { + // These tests only make sense when a relevant searcher is being used. + Preconditions.checkArgument(searchSize > 0, "Invalid searchSize. Must be positive."); + if (searcherClass.contains("Projection")) { + Preconditions.checkArgument(numProjections > 0, "Invalid numProjections. Must be positive"); + } + } + + // Setting the parameters in the Configuration. 
+ conf.setInt(DefaultOptionCreator.NUM_CLUSTERS_OPTION, numClusters); + /* StreamingKMeans */ + conf.setInt(ESTIMATED_NUM_MAP_CLUSTERS, estimatedNumMapClusters); + if (estimatedDistanceCutoff != INVALID_DISTANCE_CUTOFF) { + conf.setFloat(ESTIMATED_DISTANCE_CUTOFF, estimatedDistanceCutoff); + } + /* BallKMeans */ + conf.setInt(MAX_NUM_ITERATIONS, maxNumIterations); + conf.setFloat(TRIM_FRACTION, trimFraction); + conf.setBoolean(RANDOM_INIT, randomInit); + conf.setBoolean(IGNORE_WEIGHTS, ignoreWeights); + conf.setFloat(TEST_PROBABILITY, testProbability); + conf.setInt(NUM_BALLKMEANS_RUNS, numBallKMeansRuns); + /* Searcher */ + // Checks if the measureClass is available, throws exception otherwise. + Class.forName(measureClass); + conf.set(DefaultOptionCreator.DISTANCE_MEASURE_OPTION, measureClass); + // Checks if the searcherClass is available, throws exception otherwise. + Class.forName(searcherClass); + conf.set(SEARCHER_CLASS_OPTION, searcherClass); + conf.setInt(SEARCH_SIZE_OPTION, searchSize); + conf.setInt(NUM_PROJECTIONS_OPTION, numProjections); + conf.set(DefaultOptionCreator.METHOD_OPTION, method); + + conf.setBoolean(REDUCE_STREAMING_KMEANS, reduceStreamingKMeans); + + log.info("Parameters are: [k] numClusters {}; " + + "[SKM] estimatedNumMapClusters {}; estimatedDistanceCutoff {} " + + "[BKM] maxNumIterations {}; trimFraction {}; randomInit {}; ignoreWeights {}; " + + "testProbability {}; numBallKMeansRuns {}; " + + "[S] measureClass {}; searcherClass {}; searcherSize {}; numProjections {}; " + + "method {}; reduceStreamingKMeans {}", numClusters, estimatedNumMapClusters, estimatedDistanceCutoff, + maxNumIterations, trimFraction, randomInit, ignoreWeights, testProbability, numBallKMeansRuns, + measureClass, searcherClass, searchSize, numProjections, method, reduceStreamingKMeans); + } + + /** + * Iterate over the input vectors to produce clusters and, if requested, use the results of the final iteration to + * cluster the input vectors. 
+   *
+   * @param input the directory pathname for input points.
+   * @param output the directory pathname for output points.
+   * @return 0 on success, -1 on failure.
+   */
+  public static int run(Configuration conf, Path input, Path output)
+      throws IOException, InterruptedException, ClassNotFoundException, ExecutionException {
+    log.info("Starting StreamingKMeans clustering for vectors in {}; results are output to {}",
+        input.toString(), output.toString());
+
+    // MapReduce is the default; the sequential method must be requested explicitly.
+    String method = conf.get(DefaultOptionCreator.METHOD_OPTION, DefaultOptionCreator.MAPREDUCE_METHOD);
+    if (method.equals(DefaultOptionCreator.SEQUENTIAL_METHOD)) {
+      return runSequentially(conf, input, output);
+    }
+    return runMapReduce(conf, input, output);
+  }
+
+  /**
+   * Runs the whole pipeline in this JVM. A StreamingKMeansThread is submitted per input file
+   * (as listed through PathFilters.logsCRCFilter()), their centroids are merged, and BallKMeans
+   * is run on the merged centroids. The final centroids are written to {@code output/part-r-00000}.
+   *
+   * @return always 0; failures surface as exceptions.
+   */
+  private static int runSequentially(Configuration conf, Path input, Path output)
+      throws IOException, ExecutionException, InterruptedException {
+    long start = System.currentTimeMillis();
+    // Run StreamingKMeans step in parallel by spawning 1 thread per input path to process.
+    ExecutorService pool = Executors.newCachedThreadPool();
+    List<Centroid> intermediateCentroids = Lists.newArrayList();
+    try {
+      List<Future<Iterable<Centroid>>> intermediateCentroidFutures = Lists.newArrayList();
+      for (FileStatus status : HadoopUtil.listStatus(FileSystem.get(conf), input, PathFilters.logsCRCFilter())) {
+        intermediateCentroidFutures.add(pool.submit(new StreamingKMeansThread(status.getPath(), conf)));
+      }
+      // Merge the resulting "mapper" centroids. Future.get() blocks, so every "mapper" task has
+      // completed once this loop is done.
+      for (Future<Iterable<Centroid>> futureIterable : intermediateCentroidFutures) {
+        for (Centroid centroid : futureIterable.get()) {
+          intermediateCentroids.add(centroid);
+        }
+      }
+      log.info("Finished running Mappers");
+    } finally {
+      // Release the worker threads even when a task failed and get() threw.
+      pool.shutdown();
+    }
+    pool.awaitTermination(Long.MAX_VALUE, TimeUnit.SECONDS);
+    log.info("Finished StreamingKMeans");
+
+    SequenceFile.Writer writer = SequenceFile.createWriter(FileSystem.get(conf), conf,
+        new Path(output, "part-r-00000"), IntWritable.class, CentroidWritable.class);
+    try {
+      int numCentroids = 0;
+      // Run BallKMeans on the intermediate centroids.
+      for (Vector finalVector : StreamingKMeansReducer.getBestCentroids(intermediateCentroids, conf)) {
+        Centroid finalCentroid = (Centroid) finalVector;
+        writer.append(new IntWritable(numCentroids++), new CentroidWritable(finalCentroid));
+      }
+    } finally {
+      // Close the output file even if clustering or writing fails.
+      writer.close();
+    }
+    long end = System.currentTimeMillis();
+    log.info("Finished BallKMeans. Took {}.", (end - start) / 1000.0);
+    return 0;
+  }
+
+  /**
+   * Submits a MapReduce job in which StreamingKMeansMapper clusters each split and a single
+   * StreamingKMeansReducer produces the final centroids.
+   *
+   * @return 0 if the job completed successfully, -1 otherwise.
+   */
+  public static int runMapReduce(Configuration conf, Path input, Path output)
+      throws IOException, ClassNotFoundException, InterruptedException {
+    // Prepare Job for submission.
+    Job job = HadoopUtil.prepareJob(input, output, SequenceFileInputFormat.class,
+        StreamingKMeansMapper.class, IntWritable.class, CentroidWritable.class,
+        StreamingKMeansReducer.class, IntWritable.class, CentroidWritable.class, SequenceFileOutputFormat.class,
+        conf);
+    job.setJobName(HadoopUtil.getCustomJobName(StreamingKMeansDriver.class.getSimpleName(), job,
+        StreamingKMeansMapper.class, StreamingKMeansReducer.class));
+
+    // There is only one reducer so that the intermediate centroids get collected on one
+    // machine and are clustered in memory to get the right number of clusters.
+    job.setNumReduceTasks(1);
+
+    // Set the JAR (so that the required libraries are available) and run.
+    job.setJarByClass(StreamingKMeansDriver.class);
+
+    // Run job!
+    long start = System.currentTimeMillis();
+    if (!job.waitForCompletion(true)) {
+      return -1;
+    }
+    long end = System.currentTimeMillis();
+
+    log.info("StreamingKMeans clustering complete. Results are in {}. Took {} ms", output.toString(), end - start);
+    return 0;
+  }
+
+  /**
+   * Constructor to be used by the ToolRunner.
+   */
+  private StreamingKMeansDriver() {}
+
+  /**
+   * Command-line entry point; delegates argument handling to ToolRunner.
+   */
+  public static void main(String[] args) throws Exception {
+    ToolRunner.run(new StreamingKMeansDriver(), args);
+  }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansMapper.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansMapper.java b/mr/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansMapper.java
new file mode 100644
index 0000000..ced11ea
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansMapper.java
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.streaming.mapreduce;
+
+import java.io.IOException;
+import java.util.List;
+
+import com.google.common.collect.Lists;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.io.Writable;
+import org.apache.hadoop.mapreduce.Mapper;
+import org.apache.mahout.clustering.ClusteringUtils;
+import org.apache.mahout.clustering.streaming.cluster.StreamingKMeans;
+import org.apache.mahout.math.Centroid;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.neighborhood.UpdatableSearcher;
+
+/**
+ * Clusters the vectors of one input split online with StreamingKMeans and, in cleanup, emits the
+ * resulting centroids under the single key 0 so that they all reach the same reducer.
+ */
+public class StreamingKMeansMapper extends Mapper<Writable, VectorWritable, IntWritable, CentroidWritable> {
+  /**
+   * Number of leading points buffered to estimate the distance cutoff when none is configured.
+   */
+  private static final int NUM_ESTIMATE_POINTS = 1000;
+
+  /**
+   * The clusterer object used to cluster the points received by this mapper online.
+   */
+  private StreamingKMeans clusterer;
+
+  /**
+   * Number of points received so far; also used as the index of the next Centroid.
+   */
+  private int numPoints = 0;
+
+  /**
+   * True while the distance cutoff still has to be estimated from the buffered points.
+   */
+  private boolean estimateDistanceCutoff = false;
+
+  /**
+   * Points buffered until the distance cutoff can be estimated (only allocated when estimating).
+   */
+  private List<Centroid> estimatePoints;
+
+  @Override
+  public void setup(Context context) {
+    // At this point the configuration received from the Driver is assumed to be valid.
+    // No other checks are made.
+    Configuration conf = context.getConfiguration();
+    UpdatableSearcher searcher = StreamingKMeansUtilsMR.searcherFromConfiguration(conf);
+    int numClusters = conf.getInt(StreamingKMeansDriver.ESTIMATED_NUM_MAP_CLUSTERS, 1);
+    double estimatedDistanceCutoff = conf.getFloat(StreamingKMeansDriver.ESTIMATED_DISTANCE_CUTOFF,
+        StreamingKMeansDriver.INVALID_DISTANCE_CUTOFF);
+    if (estimatedDistanceCutoff == StreamingKMeansDriver.INVALID_DISTANCE_CUTOFF) {
+      // There is no way of estimating the distance cutoff unless we have some data, so the first
+      // points are buffered and the real cutoff is set once enough of them have arrived.
+      estimateDistanceCutoff = true;
+      estimatePoints = Lists.newArrayList();
+    }
+    clusterer = new StreamingKMeans(searcher, numClusters, estimatedDistanceCutoff);
+  }
+
+  /**
+   * Sets the distance cutoff estimated from the buffered points, clusters those points, and switches
+   * the mapper to clustering each further point directly.
+   */
+  private void clusterEstimatePoints() {
+    clusterer.setDistanceCutoff(ClusteringUtils.estimateDistanceCutoff(
+        estimatePoints, clusterer.getDistanceMeasure()));
+    clusterer.cluster(estimatePoints);
+    estimateDistanceCutoff = false;
+  }
+
+  @Override
+  public void map(Writable key, VectorWritable point, Context context) {
+    Centroid centroid = new Centroid(numPoints++, point.get(), 1);
+    if (estimateDistanceCutoff) {
+      // Add the point before checking the sample size so that the point completing the sample is
+      // clustered together with the others instead of being dropped.
+      estimatePoints.add(centroid);
+      if (estimatePoints.size() == NUM_ESTIMATE_POINTS) {
+        clusterEstimatePoints();
+      }
+    } else {
+      clusterer.cluster(centroid);
+    }
+  }
+
+  @Override
+  public void cleanup(Context context) throws IOException, InterruptedException {
+    // We should cluster the points at the end if they haven't yet been clustered.
+    if (estimateDistanceCutoff) {
+      clusterEstimatePoints();
+    }
+    // Reindex the centroids before passing them to the reducer.
+    clusterer.reindexCentroids();
+    // All outputs have the same key to go to the same final reducer; the key is serialized on
+    // write, so a single instance can be reused.
+    IntWritable sameKey = new IntWritable(0);
+    for (Centroid centroid : clusterer) {
+      context.write(sameKey, new CentroidWritable(centroid));
+    }
+  }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansReducer.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansReducer.java b/mr/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansReducer.java
new file mode 100644
index 0000000..2b78acc
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansReducer.java
@@ -0,0 +1,109 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.
See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.streaming.mapreduce;
+
+import java.io.IOException;
+import java.util.List;
+
+import com.google.common.base.Function;
+import com.google.common.base.Preconditions;
+import com.google.common.collect.Iterables;
+import com.google.common.collect.Lists;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.mapreduce.Reducer;
+import org.apache.mahout.clustering.streaming.cluster.BallKMeans;
+import org.apache.mahout.common.commandline.DefaultOptionCreator;
+import org.apache.mahout.math.Centroid;
+import org.apache.mahout.math.Vector;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Collects the intermediate centroids, optionally collapses them with another StreamingKMeans pass,
+ * and clusters them into the final centroids with BallKMeans.
+ */
+public class StreamingKMeansReducer extends Reducer<IntWritable, CentroidWritable, IntWritable, CentroidWritable> {
+
+  private static final Logger log = LoggerFactory.getLogger(StreamingKMeansReducer.class);
+
+  /**
+   * Extracts a private copy of the Centroid held by a (non-null) CentroidWritable. The copy is needed
+   * because Hadoop refills the same Writable instance while iterating over reduce values.
+   */
+  private static final Function<CentroidWritable, Centroid> CLONE_CENTROID =
+      new Function<CentroidWritable, Centroid>() {
+        @Override
+        public Centroid apply(CentroidWritable input) {
+          Preconditions.checkNotNull(input);
+          return input.getCentroid().clone();
+        }
+      };
+
+  /**
+   * Configuration for the MapReduce job.
+   */
+  private Configuration conf;
+
+  @Override
+  public void setup(Context context) {
+    // At this point the configuration received from the Driver is assumed to be valid.
+    // No other checks are made.
+    conf = context.getConfiguration();
+  }
+
+  @Override
+  public void reduce(IntWritable key, Iterable<CentroidWritable> centroids,
+                     Context context) throws IOException, InterruptedException {
+    List<Centroid> intermediateCentroids;
+    // There might be too many intermediate centroids to fit into memory, in which case, we run another pass
+    // of StreamingKMeans to collapse the clusters further.
+    if (conf.getBoolean(StreamingKMeansDriver.REDUCE_STREAMING_KMEANS, false)) {
+      // The transformed view is lazy, so the values are streamed through StreamingKMeansThread rather than
+      // materialized. NOTE(review): Hadoop reduce values are normally single-pass, so the thread must
+      // iterate this view only once — confirm against StreamingKMeansThread.call().
+      intermediateCentroids = Lists.newArrayList(
+          new StreamingKMeansThread(Iterables.transform(centroids, CLONE_CENTROID), conf).call());
+    } else {
+      intermediateCentroids = centroidWritablesToList(centroids);
+    }
+
+    int index = 0;
+    for (Vector centroid : getBestCentroids(intermediateCentroids, conf)) {
+      context.write(new IntWritable(index), new CentroidWritable((Centroid) centroid));
+      ++index;
+    }
+  }
+
+  /**
+   * Copies every centroid out of {@code centroids} into a new list.
+   *
+   * @param centroids the writables to copy; no element may be null.
+   * @return a new mutable list with one cloned Centroid per writable.
+   */
+  public static List<Centroid> centroidWritablesToList(Iterable<CentroidWritable> centroids) {
+    // A new list must be created because Hadoop iterators mutate the contents of the Writable in
+    // place, without allocating new references when iterating through the centroids Iterable.
+    return Lists.newArrayList(Iterables.transform(centroids, CLONE_CENTROID));
+  }
+
+  /**
+   * Runs BallKMeans on {@code centroids} using the clustering parameters from {@code conf}. The defaults are:
+   * 1 cluster, 10 iterations, trim fraction 0.9, k-means++ seeding, weight correction, test
+   * probability 0.1 and 3 runs.
+   *
+   * @return the centroids produced by BallKMeans.
+   */
+  public static Iterable<Vector> getBestCentroids(List<Centroid> centroids, Configuration conf) {
+    log.info("Number of Centroids: {}", centroids.size());
+
+    int numClusters = conf.getInt(DefaultOptionCreator.NUM_CLUSTERS_OPTION, 1);
+    int maxNumIterations = conf.getInt(StreamingKMeansDriver.MAX_NUM_ITERATIONS, 10);
+    float trimFraction = conf.getFloat(StreamingKMeansDriver.TRIM_FRACTION, 0.9f);
+    boolean kMeansPlusPlusInit = !conf.getBoolean(StreamingKMeansDriver.RANDOM_INIT, false);
+    boolean correctWeights = !conf.getBoolean(StreamingKMeansDriver.IGNORE_WEIGHTS, false);
+    float testProbability = conf.getFloat(StreamingKMeansDriver.TEST_PROBABILITY, 0.1f);
+    int numRuns = conf.getInt(StreamingKMeansDriver.NUM_BALLKMEANS_RUNS, 3);
+
+    BallKMeans ballKMeansCluster = new BallKMeans(StreamingKMeansUtilsMR.searcherFromConfiguration(conf),
+        numClusters, maxNumIterations, trimFraction, kMeansPlusPlusInit, correctWeights, testProbability, numRuns);
+    return ballKMeansCluster.cluster(centroids);
+  }
+}
http://git-wip-us.apache.org/repos/asf/mahout/blob/b988c493/mr/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansThread.java
----------------------------------------------------------------------
diff --git a/mr/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansThread.java b/mr/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansThread.java
new file mode 100644
index 0000000..acb2b56
--- /dev/null
+++ b/mr/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansThread.java
@@ -0,0 +1,92 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor
license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.streaming.mapreduce;
+
+import java.util.Iterator;
+import java.util.List;
+import java.util.concurrent.Callable;
+
+import com.google.common.collect.Lists;
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.mahout.clustering.ClusteringUtils;
+import org.apache.mahout.clustering.streaming.cluster.StreamingKMeans;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileValueIterable;
+import org.apache.mahout.math.Centroid;
+import org.apache.mahout.math.VectorWritable;
+import org.apache.mahout.math.neighborhood.UpdatableSearcher;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Runs one StreamingKMeans pass over a set of points and returns the resulting clusterer, whose
+ * iteration yields the (reindexed) centroids.
+ */
+public class StreamingKMeansThread implements Callable<Iterable<Centroid>> {
+  private static final Logger log = LoggerFactory.getLogger(StreamingKMeansThread.class);
+
+  /**
+   * Number of leading points used to estimate the distance cutoff when none is configured.
+   */
+  private static final int NUM_ESTIMATE_POINTS = 1000;
+
+  private final Configuration conf;
+  private final Iterable<Centroid> dataPoints;
+
+  /**
+   * Clusters the VectorWritable values stored in the sequence file at {@code input}.
+   */
+  public StreamingKMeansThread(Path input, Configuration conf) {
+    this(StreamingKMeansUtilsMR.getCentroidsFromVectorWritable(
+        new SequenceFileValueIterable<VectorWritable>(input, false, conf)), conf);
+  }
+
+  /**
+   * Clusters the given points. {@link #call()} requests exactly one iterator from {@code dataPoints},
+   * so single-use iterables are supported.
+   */
+  public StreamingKMeansThread(Iterable<Centroid> dataPoints, Configuration conf) {
+    this.dataPoints = dataPoints;
+    this.conf = conf;
+  }
+
+  @Override
+  public Iterable<Centroid> call() {
+    UpdatableSearcher searcher = StreamingKMeansUtilsMR.searcherFromConfiguration(conf);
+    int numClusters = conf.getInt(StreamingKMeansDriver.ESTIMATED_NUM_MAP_CLUSTERS, 1);
+    double estimateDistanceCutoff = conf.getFloat(StreamingKMeansDriver.ESTIMATED_DISTANCE_CUTOFF,
+        StreamingKMeansDriver.INVALID_DISTANCE_CUTOFF);
+
+    // A single pass over dataPoints: points read ahead for the cutoff estimate are buffered and
+    // clustered first, so none of them is lost and the input is never iterated twice.
+    Iterator<Centroid> dataPointsIterator = dataPoints.iterator();
+    List<Centroid> estimatePoints = Lists.newArrayListWithExpectedSize(NUM_ESTIMATE_POINTS);
+
+    if (estimateDistanceCutoff == StreamingKMeansDriver.INVALID_DISTANCE_CUTOFF) {
+      while (dataPointsIterator.hasNext() && estimatePoints.size() < NUM_ESTIMATE_POINTS) {
+        estimatePoints.add(dataPointsIterator.next());
+      }
+
+      if (log.isInfoEnabled()) {
+        log.info("Estimated Points: {}", estimatePoints.size());
+      }
+      estimateDistanceCutoff = ClusteringUtils.estimateDistanceCutoff(estimatePoints, searcher.getDistanceMeasure());
+    }
+
+    StreamingKMeans streamingKMeans = new StreamingKMeans(searcher, numClusters, estimateDistanceCutoff);
+
+    // Input order is preserved: first the buffered estimate points, then the rest of the stream.
+    for (Centroid estimatePoint : estimatePoints) {
+      streamingKMeans.cluster(estimatePoint);
+    }
+    while (dataPointsIterator.hasNext()) {
+      streamingKMeans.cluster(dataPointsIterator.next());
+    }
+
+    streamingKMeans.reindexCentroids();
+    return streamingKMeans;
+  }
+
+}
