http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansDriver.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansDriver.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansDriver.java new file mode 100644 index 0000000..0f6f7f2 --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansDriver.java @@ -0,0 +1,493 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.clustering.streaming.mapreduce; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; + +import com.google.common.base.Preconditions; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FileStatus; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.SequenceFile; +import org.apache.hadoop.mapreduce.Job; +import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat; +import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat; +import org.apache.hadoop.util.ToolRunner; +import org.apache.mahout.common.AbstractJob; +import org.apache.mahout.common.HadoopUtil; +import org.apache.mahout.common.commandline.DefaultOptionCreator; +import org.apache.mahout.common.iterator.sequencefile.PathFilters; +import org.apache.mahout.math.Centroid; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.neighborhood.BruteSearch; +import org.apache.mahout.math.neighborhood.ProjectionSearch; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Classifies the vectors into different clusters found by the clustering + * algorithm. + */ +public final class StreamingKMeansDriver extends AbstractJob { + /** + * Streaming KMeans options + */ + /** + * The number of cluster that Mappers will use should be \(O(k log n)\) where k is the number of clusters + * to get at the end and n is the number of points to cluster. This doesn't need to be exact. + * It will be adjusted at runtime. + */ + public static final String ESTIMATED_NUM_MAP_CLUSTERS = "estimatedNumMapClusters"; + /** + * The initial estimated distance cutoff between two points for forming new clusters. 
+ * @see org.apache.mahout.clustering.streaming.cluster.StreamingKMeans + * Defaults to 10e-6. + */ + public static final String ESTIMATED_DISTANCE_CUTOFF = "estimatedDistanceCutoff"; + + /** + * Ball KMeans options + */ + /** + * After mapping finishes, we get an intermediate set of vectors that represent approximate + * clusterings of the data from each Mapper. These can be clustered by the Reducer using + * BallKMeans in memory. This variable is the maximum number of iterations in the final + * BallKMeans algorithm. + * Defaults to 10. + */ + public static final String MAX_NUM_ITERATIONS = "maxNumIterations"; + /** + * The "ball" aspect of ball k-means means that only the closest points to the centroid will actually be used + * for updating. The fraction of the points to be used is those points whose distance to the center is within + * trimFraction * distance to the closest other center. + * Defaults to 0.9. + */ + public static final String TRIM_FRACTION = "trimFraction"; + /** + * Whether to use k-means++ initialization or random initialization of the seed centroids. + * Essentially, k-means++ provides better clusters, but takes longer, whereas random initialization takes less + * time, but produces worse clusters, and tends to fail more often and needs multiple runs to compare to + * k-means++. If set, uses randomInit. + * @see org.apache.mahout.clustering.streaming.cluster.BallKMeans + */ + public static final String RANDOM_INIT = "randomInit"; + /** + * Whether to correct the weights of the centroids after the clustering is done. The weights end up being wrong + * because of the trimFraction and possible train/test splits. In some cases, especially in a pipeline, having + * an accurate count of the weights is useful. If set, ignores the final weights. + */ + public static final String IGNORE_WEIGHTS = "ignoreWeights"; + /** + * The percentage of points that go into the "test" set when evaluating BallKMeans runs in the reducer. + */ + public static final String TEST_PROBABILITY = "testProbability"; + /** + * The number of BallKMeans runs to perform in the reducer; the run whose clusters score best on the "test" set + * determines the centroids that are returned. + */ + public static final String NUM_BALLKMEANS_RUNS = "numBallKMeansRuns"; + + /** + Searcher options + */ + /** + * The Searcher class to use when performing nearest neighbor search in StreamingKMeans. + * Defaults to ProjectionSearch. + */ + public static final String SEARCHER_CLASS_OPTION = "searcherClass"; + /** + * The number of projections to use when using a projection searcher like ProjectionSearch or + * FastProjectionSearch. Projection searches work by projecting all the vectors onto a set of + * basis vectors and searching for the projected query in that totally ordered set. This, + * however, can produce false positives (vectors that are closer when projected than they would + * actually be). + * So, there must be more than one projection vector in the basis. This variable is the number + * of vectors in a basis. + * Defaults to 3. + */ + public static final String NUM_PROJECTIONS_OPTION = "numProjections"; + /** + * When using approximate searches (anything that's not BruteSearch), + * more than just the seemingly closest element must be considered. This variable has different + * meanings depending on the actual Searcher class used but is a measure of how many candidates + * will be considered. + * See the ProjectionSearch, FastProjectionSearch, LocalitySensitiveHashSearch classes for more + * details. + * Defaults to 2.
+ */ + public static final String SEARCH_SIZE_OPTION = "searchSize"; + + /** + * Whether to run another pass of StreamingKMeans on the reducer's points before BallKMeans. On some data sets + * with a large number of mappers, the intermediate number of clusters passed to the reducer is too large to + * fit into memory directly, hence the option to collapse the clusters further with StreamingKMeans. + */ + public static final String REDUCE_STREAMING_KMEANS = "reduceStreamingKMeans"; + + private static final Logger log = LoggerFactory.getLogger(StreamingKMeansDriver.class); + + public static final float INVALID_DISTANCE_CUTOFF = -1; + + @Override + public int run(String[] args) throws Exception { + // Standard options for any Mahout job. + addInputOption(); + addOutputOption(); + addOption(DefaultOptionCreator.overwriteOption().create()); + + // The number of clusters to create for the data. + addOption(DefaultOptionCreator.numClustersOption().withDescription( + "The k in k-Means. Approximately this many clusters will be generated.").create()); + + // StreamingKMeans (mapper) options + // There will be k final clusters, but in the Map phase to get a good approximation of the data, O(k log n) + // clusters are needed. Since n is the number of data points and not knowable until reading all the vectors, + // provide a decent estimate. + addOption(ESTIMATED_NUM_MAP_CLUSTERS, "km", "The estimated number of clusters to use for the " + + "Map phase of the job when running StreamingKMeans. This should be around k * log(n), " + + "where k is the final number of clusters and n is the total number of data points to " + + "cluster.", String.valueOf(1)); + + addOption(ESTIMATED_DISTANCE_CUTOFF, "e", "The initial estimated distance cutoff between two " + + "points for forming new clusters. If no value is given, it's estimated from the data set", + String.valueOf(INVALID_DISTANCE_CUTOFF)); + + // BallKMeans (reducer) options + addOption(MAX_NUM_ITERATIONS, "mi", "The maximum number of iterations to run for the " + + "BallKMeans algorithm used by the reducer. If no value is given, defaults to 10.", String.valueOf(10)); + + addOption(TRIM_FRACTION, "tf", "The 'ball' aspect of ball k-means means that only the closest points " + + "to the centroid will actually be used for updating. The fraction of the points to be used is those " + + "points whose distance to the center is within trimFraction * distance to the closest other center. " + + "If no value is given, defaults to 0.9.", String.valueOf(0.9)); + + addFlag(RANDOM_INIT, "ri", "Whether to use k-means++ initialization or random initialization " + + "of the seed centroids. Essentially, k-means++ provides better clusters, but takes longer, whereas random " + + "initialization takes less time, but produces worse clusters, and tends to fail more often and needs " + + "multiple runs to compare to k-means++. If set, uses the random initialization."); + + addFlag(IGNORE_WEIGHTS, "iw", "Whether to correct the weights of the centroids after the clustering is done. " + + "The weights end up being wrong because of the trimFraction and possible train/test splits. In some cases, " + + "especially in a pipeline, having an accurate count of the weights is useful. If set, ignores the final " + + "weights"); + + addOption(TEST_PROBABILITY, "testp", "A double value between 0 and 1 that represents the percentage of " + + "points to be used for 'testing' different clustering runs in the final BallKMeans " + + "step. 
If no value is given, defaults to 0.1", String.valueOf(0.1)); + + addOption(NUM_BALLKMEANS_RUNS, "nbkm", "Number of BallKMeans runs to use at the end to try to cluster the " + + "points. If no value is given, defaults to 4", String.valueOf(4)); + + // Nearest neighbor search options + // The distance measure used for computing the distance between two points. Generally, the + // SquaredEuclideanDistance is used for clustering problems (it's equivalent to CosineDistance for normalized + // vectors). + // WARNING! You can use any metric but most of the literature is for the squared euclidean distance. + addOption(DefaultOptionCreator.distanceMeasureOption().create()); + + // The default searcher should be something more efficient that BruteSearch (ProjectionSearch, ...). See + // o.a.m.math.neighborhood.* + addOption(SEARCHER_CLASS_OPTION, "sc", "The type of searcher to be used when performing nearest " + + "neighbor searches. Defaults to ProjectionSearch.", ProjectionSearch.class.getCanonicalName()); + + // In the original paper, the authors used 1 projection vector. + addOption(NUM_PROJECTIONS_OPTION, "np", "The number of projections considered in estimating the " + + "distances between vectors. Only used when the distance measure requested is either " + + "ProjectionSearch or FastProjectionSearch. If no value is given, defaults to 3.", String.valueOf(3)); + + addOption(SEARCH_SIZE_OPTION, "s", "In more efficient searches (non BruteSearch), " + + "not all distances are calculated for determining the nearest neighbors. The number of " + + "elements whose distances from the query vector is actually computer is proportional to " + + "searchSize. If no value is given, defaults to 1.", String.valueOf(2)); + + addFlag(REDUCE_STREAMING_KMEANS, "rskm", "There might be too many intermediate clusters from the mapper " + + "to fit into memory, so the reducer can run another pass of StreamingKMeans to collapse them down to a " + + "fewer clusters"); + + addOption(DefaultOptionCreator.methodOption().create()); + + if (parseArguments(args) == null) { + return -1; + } + Path output = getOutputPath(); + if (hasOption(DefaultOptionCreator.OVERWRITE_OPTION)) { + HadoopUtil.delete(getConf(), output); + } + configureOptionsForWorkers(); + run(getConf(), getInputPath(), output); + return 0; + } + + private void configureOptionsForWorkers() throws ClassNotFoundException { + log.info("Starting to configure options for workers"); + + String method = getOption(DefaultOptionCreator.METHOD_OPTION); + + int numClusters = Integer.parseInt(getOption(DefaultOptionCreator.NUM_CLUSTERS_OPTION)); + + // StreamingKMeans + int estimatedNumMapClusters = Integer.parseInt(getOption(ESTIMATED_NUM_MAP_CLUSTERS)); + float estimatedDistanceCutoff = Float.parseFloat(getOption(ESTIMATED_DISTANCE_CUTOFF)); + + // BallKMeans + int maxNumIterations = Integer.parseInt(getOption(MAX_NUM_ITERATIONS)); + float trimFraction = Float.parseFloat(getOption(TRIM_FRACTION)); + boolean randomInit = hasOption(RANDOM_INIT); + boolean ignoreWeights = hasOption(IGNORE_WEIGHTS); + float testProbability = Float.parseFloat(getOption(TEST_PROBABILITY)); + int numBallKMeansRuns = Integer.parseInt(getOption(NUM_BALLKMEANS_RUNS)); + + // Nearest neighbor search + String measureClass = getOption(DefaultOptionCreator.DISTANCE_MEASURE_OPTION); + String searcherClass = getOption(SEARCHER_CLASS_OPTION); + + // Get more parameters depending on the kind of search class we're working with. BruteSearch + // doesn't need anything else. 
+ // LocalitySensitiveHashSearch and ProjectionSearches need searchSize. + // ProjectionSearches also need the number of projections. + boolean getSearchSize = false; + boolean getNumProjections = false; + if (!searcherClass.equals(BruteSearch.class.getName())) { + getSearchSize = true; + getNumProjections = true; + } + + // The search size to use. This is quite fuzzy and might end up not being configurable at all. + int searchSize = 0; + if (getSearchSize) { + searchSize = Integer.parseInt(getOption(SEARCH_SIZE_OPTION)); + } + + // The number of projections to use. This is only useful in projection searches which + // project the vectors on multiple basis vectors to get distance estimates that are faster to + // calculate. + int numProjections = 0; + if (getNumProjections) { + numProjections = Integer.parseInt(getOption(NUM_PROJECTIONS_OPTION)); + } + + boolean reduceStreamingKMeans = hasOption(REDUCE_STREAMING_KMEANS); + + configureOptionsForWorkers(getConf(), numClusters, + /* StreamingKMeans */ + estimatedNumMapClusters, estimatedDistanceCutoff, + /* BallKMeans */ + maxNumIterations, trimFraction, randomInit, ignoreWeights, testProbability, numBallKMeansRuns, + /* Searcher */ + measureClass, searcherClass, searchSize, numProjections, + method, + reduceStreamingKMeans); + } + + /** + * Checks the parameters for a StreamingKMeans job and prepares a Configuration with them. + * + * @param conf the Configuration to populate + * @param numClusters k, the number of clusters at the end + * @param estimatedNumMapClusters O(k log n), the number of clusters requested from each mapper + * @param estimatedDistanceCutoff an estimate of the minimum distance that separates two clusters (can be smaller and + * will be increased dynamically) + * @param maxNumIterations the maximum number of iterations of BallKMeans + * @param trimFraction the fraction of the points to be considered in updating a ball k-means + * @param randomInit whether to initialize the ball k-means seeds randomly + * @param ignoreWeights whether to ignore the invalid final ball k-means weights + * @param testProbability the percentage of vectors assigned to the test set for selecting the best final centers + * @param numBallKMeansRuns the number of BallKMeans runs in the reducer that determine the centroids to return + * (clusters are computed for the training set and the error is computed on the test set) + * @param measureClass string, name of the distance measure class; theory works for Euclidean-like distances + * @param searcherClass string, name of the searcher that will be used for nearest neighbor search + * @param searchSize the number of closest neighbors to look at for selecting the closest one in approximate nearest + * neighbor searches + * @param numProjections the number of projected vectors to use for faster searching (only useful for ProjectionSearch + * or FastProjectionSearch); @see org.apache.mahout.math.neighborhood.ProjectionSearch + */ + public static void configureOptionsForWorkers(Configuration conf, + int numClusters, + /* StreamingKMeans */ + int estimatedNumMapClusters, float estimatedDistanceCutoff, + /* BallKMeans */ + int maxNumIterations, float trimFraction, boolean randomInit, + boolean ignoreWeights, float testProbability, int numBallKMeansRuns, + /* Searcher */ + String measureClass, String searcherClass, + int searchSize, int numProjections, + String method, + boolean reduceStreamingKMeans) throws ClassNotFoundException { + // Checking preconditions for the parameters. 
+ Preconditions.checkArgument(numClusters > 0, + "Invalid number of clusters requested: " + numClusters + ". Must be: numClusters > 0!"); + + // StreamingKMeans + Preconditions.checkArgument(estimatedNumMapClusters > numClusters, "Invalid number of estimated map " + + "clusters; There must be more than the final number of clusters (k log n vs k)"); + Preconditions.checkArgument(estimatedDistanceCutoff == INVALID_DISTANCE_CUTOFF || estimatedDistanceCutoff > 0, + "estimatedDistanceCutoff must be equal to -1 or must be greater then 0!"); + + // BallKMeans + Preconditions.checkArgument(maxNumIterations > 0, "Must have at least one BallKMeans iteration"); + Preconditions.checkArgument(trimFraction > 0, "trimFraction must be positive"); + Preconditions.checkArgument(testProbability >= 0 && testProbability < 1, "test probability is not in the " + + "interval [0, 1)"); + Preconditions.checkArgument(numBallKMeansRuns > 0, "numBallKMeans cannot be negative"); + + // Searcher + if (!searcherClass.contains("Brute")) { + // These tests only make sense when a relevant searcher is being used. + Preconditions.checkArgument(searchSize > 0, "Invalid searchSize. Must be positive."); + if (searcherClass.contains("Projection")) { + Preconditions.checkArgument(numProjections > 0, "Invalid numProjections. Must be positive"); + } + } + + // Setting the parameters in the Configuration. + conf.setInt(DefaultOptionCreator.NUM_CLUSTERS_OPTION, numClusters); + /* StreamingKMeans */ + conf.setInt(ESTIMATED_NUM_MAP_CLUSTERS, estimatedNumMapClusters); + if (estimatedDistanceCutoff != INVALID_DISTANCE_CUTOFF) { + conf.setFloat(ESTIMATED_DISTANCE_CUTOFF, estimatedDistanceCutoff); + } + /* BallKMeans */ + conf.setInt(MAX_NUM_ITERATIONS, maxNumIterations); + conf.setFloat(TRIM_FRACTION, trimFraction); + conf.setBoolean(RANDOM_INIT, randomInit); + conf.setBoolean(IGNORE_WEIGHTS, ignoreWeights); + conf.setFloat(TEST_PROBABILITY, testProbability); + conf.setInt(NUM_BALLKMEANS_RUNS, numBallKMeansRuns); + /* Searcher */ + // Checks if the measureClass is available, throws exception otherwise. + Class.forName(measureClass); + conf.set(DefaultOptionCreator.DISTANCE_MEASURE_OPTION, measureClass); + // Checks if the searcherClass is available, throws exception otherwise. + Class.forName(searcherClass); + conf.set(SEARCHER_CLASS_OPTION, searcherClass); + conf.setInt(SEARCH_SIZE_OPTION, searchSize); + conf.setInt(NUM_PROJECTIONS_OPTION, numProjections); + conf.set(DefaultOptionCreator.METHOD_OPTION, method); + + conf.setBoolean(REDUCE_STREAMING_KMEANS, reduceStreamingKMeans); + + log.info("Parameters are: [k] numClusters {}; " + + "[SKM] estimatedNumMapClusters {}; estimatedDistanceCutoff {} " + + "[BKM] maxNumIterations {}; trimFraction {}; randomInit {}; ignoreWeights {}; " + + "testProbability {}; numBallKMeansRuns {}; " + + "[S] measureClass {}; searcherClass {}; searcherSize {}; numProjections {}; " + + "method {}; reduceStreamingKMeans {}", numClusters, estimatedNumMapClusters, estimatedDistanceCutoff, + maxNumIterations, trimFraction, randomInit, ignoreWeights, testProbability, numBallKMeansRuns, + measureClass, searcherClass, searchSize, numProjections, method, reduceStreamingKMeans); + } + + /** + * Iterate over the input vectors to produce clusters and, if requested, use the results of the final iteration to + * cluster the input vectors. + * + * @param input the directory pathname for input points. + * @param output the directory pathname for output points. + * @return 0 on success, -1 on failure. 
+ */ + public static int run(Configuration conf, Path input, Path output) + throws IOException, InterruptedException, ClassNotFoundException, ExecutionException { + log.info("Starting StreamingKMeans clustering for vectors in {}; results are output to {}", + input.toString(), output.toString()); + + if (conf.get(DefaultOptionCreator.METHOD_OPTION, + DefaultOptionCreator.MAPREDUCE_METHOD).equals(DefaultOptionCreator.SEQUENTIAL_METHOD)) { + return runSequentially(conf, input, output); + } else { + return runMapReduce(conf, input, output); + } + } + + private static int runSequentially(Configuration conf, Path input, Path output) + throws IOException, ExecutionException, InterruptedException { + long start = System.currentTimeMillis(); + // Run StreamingKMeans step in parallel by spawning 1 thread per input path to process. + ExecutorService pool = Executors.newCachedThreadPool(); + List<Future<Iterable<Centroid>>> intermediateCentroidFutures = new ArrayList<>(); + for (FileStatus status : HadoopUtil.listStatus(FileSystem.get(conf), input, PathFilters.logsCRCFilter())) { + intermediateCentroidFutures.add(pool.submit(new StreamingKMeansThread(status.getPath(), conf))); + } + log.info("Finished running Mappers"); + // Merge the resulting "mapper" centroids. + List<Centroid> intermediateCentroids = new ArrayList<>(); + for (Future<Iterable<Centroid>> futureIterable : intermediateCentroidFutures) { + for (Centroid centroid : futureIterable.get()) { + intermediateCentroids.add(centroid); + } + } + pool.shutdown(); + pool.awaitTermination(Long.MAX_VALUE, TimeUnit.SECONDS); + log.info("Finished StreamingKMeans"); + SequenceFile.Writer writer = SequenceFile.createWriter(FileSystem.get(conf), conf, new Path(output, "part-r-00000"), IntWritable.class, + CentroidWritable.class); + int numCentroids = 0; + // Run BallKMeans on the intermediate centroids. + for (Vector finalVector : StreamingKMeansReducer.getBestCentroids(intermediateCentroids, conf)) { + Centroid finalCentroid = (Centroid)finalVector; + writer.append(new IntWritable(numCentroids++), new CentroidWritable(finalCentroid)); + } + writer.close(); + long end = System.currentTimeMillis(); + log.info("Finished BallKMeans. Took {}.", (end - start) / 1000.0); + return 0; + } + + public static int runMapReduce(Configuration conf, Path input, Path output) + throws IOException, ClassNotFoundException, InterruptedException { + // Prepare Job for submission. + Job job = HadoopUtil.prepareJob(input, output, SequenceFileInputFormat.class, + StreamingKMeansMapper.class, IntWritable.class, CentroidWritable.class, + StreamingKMeansReducer.class, IntWritable.class, CentroidWritable.class, SequenceFileOutputFormat.class, + conf); + job.setJobName(HadoopUtil.getCustomJobName(StreamingKMeansDriver.class.getSimpleName(), job, + StreamingKMeansMapper.class, StreamingKMeansReducer.class)); + + // There is only one reducer so that the intermediate centroids get collected on one + // machine and are clustered in memory to get the right number of clusters. + job.setNumReduceTasks(1); + + // Set the JAR (so that the required libraries are available) and run. + job.setJarByClass(StreamingKMeansDriver.class); + + // Run job! + long start = System.currentTimeMillis(); + if (!job.waitForCompletion(true)) { + return -1; + } + long end = System.currentTimeMillis(); + + log.info("StreamingKMeans clustering complete. Results are in {}. Took {} ms", output.toString(), end - start); + return 0; + } + + /** + * Constructor to be used by the ToolRunner. 
+ */ + private StreamingKMeansDriver() {} + + public static void main(String[] args) throws Exception { + ToolRunner.run(new StreamingKMeansDriver(), args); + } +}
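A minimal usage sketch, not part of the patch above: besides the command line handled by run(String[]), the driver exposes two public static entry points, configureOptionsForWorkers and run(Configuration, Path, Path), that can be called directly. Every path and numeric value below is an illustrative assumption.

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.mahout.clustering.streaming.mapreduce.StreamingKMeansDriver;
import org.apache.mahout.common.commandline.DefaultOptionCreator;
import org.apache.mahout.common.distance.SquaredEuclideanDistanceMeasure;
import org.apache.mahout.math.neighborhood.ProjectionSearch;

public class StreamingKMeansExample {
  public static void main(String[] args) throws Exception {
    Configuration conf = new Configuration();
    // Validates the parameters and writes them into the Configuration, exactly as the CLI path does.
    StreamingKMeansDriver.configureOptionsForWorkers(conf,
        20,                                            // k, the final number of clusters
        200,                                           // ~ k * log(n) clusters for the map phase
        StreamingKMeansDriver.INVALID_DISTANCE_CUTOFF, // estimate the distance cutoff from the data
        10,                                            // maximum BallKMeans iterations in the reducer
        0.9f,                                          // trimFraction
        false,                                         // randomInit: false means k-means++ seeding
        false,                                         // ignoreWeights
        0.1f,                                          // testProbability
        4,                                             // numBallKMeansRuns
        SquaredEuclideanDistanceMeasure.class.getName(),
        ProjectionSearch.class.getName(),
        2,                                             // searchSize
        3,                                             // numProjections
        DefaultOptionCreator.MAPREDUCE_METHOD,
        false);                                        // no extra StreamingKMeans pass in the reducer
    // Runs the clustering; results are written as IntWritable/CentroidWritable pairs under the output path.
    StreamingKMeansDriver.run(conf, new Path("/tmp/input-vectors"), new Path("/tmp/skm-output"));
  }
}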
http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansMapper.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansMapper.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansMapper.java new file mode 100644 index 0000000..f12a876 --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansMapper.java @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.clustering.streaming.mapreduce; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.Writable; +import org.apache.hadoop.mapreduce.Mapper; +import org.apache.mahout.clustering.ClusteringUtils; +import org.apache.mahout.clustering.streaming.cluster.StreamingKMeans; +import org.apache.mahout.math.Centroid; +import org.apache.mahout.math.VectorWritable; +import org.apache.mahout.math.neighborhood.UpdatableSearcher; + +public class StreamingKMeansMapper extends Mapper<Writable, VectorWritable, IntWritable, CentroidWritable> { + private static final int NUM_ESTIMATE_POINTS = 1000; + + /** + * The clusterer object used to cluster the points received by this mapper online. + */ + private StreamingKMeans clusterer; + + /** + * Number of points clustered so far. + */ + private int numPoints = 0; + + private boolean estimateDistanceCutoff = false; + + private List<Centroid> estimatePoints; + + @Override + public void setup(Context context) { + // At this point the configuration received from the Driver is assumed to be valid. + // No other checks are made. + Configuration conf = context.getConfiguration(); + UpdatableSearcher searcher = StreamingKMeansUtilsMR.searcherFromConfiguration(conf); + int numClusters = conf.getInt(StreamingKMeansDriver.ESTIMATED_NUM_MAP_CLUSTERS, 1); + double estimatedDistanceCutoff = conf.getFloat(StreamingKMeansDriver.ESTIMATED_DISTANCE_CUTOFF, + StreamingKMeansDriver.INVALID_DISTANCE_CUTOFF); + if (estimatedDistanceCutoff == StreamingKMeansDriver.INVALID_DISTANCE_CUTOFF) { + estimateDistanceCutoff = true; + estimatePoints = new ArrayList<>(); + } + // There is no way of estimating the distance cutoff unless we have some data. 
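+ // In that case, map() buffers the first points (up to NUM_ESTIMATE_POINTS) and clusterEstimatePoints() estimates the cutoff from them before any clustering happens; cleanup() covers inputs smaller than that.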
+ clusterer = new StreamingKMeans(searcher, numClusters, estimatedDistanceCutoff); + } + + private void clusterEstimatePoints() { + clusterer.setDistanceCutoff(ClusteringUtils.estimateDistanceCutoff( + estimatePoints, clusterer.getDistanceMeasure())); + clusterer.cluster(estimatePoints); + estimateDistanceCutoff = false; + } + + @Override + public void map(Writable key, VectorWritable point, Context context) { + Centroid centroid = new Centroid(numPoints++, point.get(), 1); + if (estimateDistanceCutoff) { + if (numPoints < NUM_ESTIMATE_POINTS) { + estimatePoints.add(centroid); + } else if (numPoints == NUM_ESTIMATE_POINTS) { + clusterEstimatePoints(); + } + } else { + clusterer.cluster(centroid); + } + } + + @Override + public void cleanup(Context context) throws IOException, InterruptedException { + // We should cluster the points at the end if they haven't yet been clustered. + if (estimateDistanceCutoff) { + clusterEstimatePoints(); + } + // Reindex the centroids before passing them to the reducer. + clusterer.reindexCentroids(); + // All outputs have the same key to go to the same final reducer. + for (Centroid centroid : clusterer) { + context.write(new IntWritable(0), new CentroidWritable(centroid)); + } + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansReducer.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansReducer.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansReducer.java new file mode 100644 index 0000000..2b78acc --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansReducer.java @@ -0,0 +1,109 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.clustering.streaming.mapreduce; + +import java.io.IOException; +import java.util.List; + +import com.google.common.base.Function; +import com.google.common.base.Preconditions; +import com.google.common.collect.Iterables; +import com.google.common.collect.Lists; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.mapreduce.Reducer; +import org.apache.mahout.clustering.streaming.cluster.BallKMeans; +import org.apache.mahout.common.commandline.DefaultOptionCreator; +import org.apache.mahout.math.Centroid; +import org.apache.mahout.math.Vector; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class StreamingKMeansReducer extends Reducer<IntWritable, CentroidWritable, IntWritable, CentroidWritable> { + + private static final Logger log = LoggerFactory.getLogger(StreamingKMeansReducer.class); + + /** + * Configuration for the MapReduce job. + */ + private Configuration conf; + + @Override + public void setup(Context context) { + // At this point the configuration received from the Driver is assumed to be valid. + // No other checks are made. + conf = context.getConfiguration(); + } + + @Override + public void reduce(IntWritable key, Iterable<CentroidWritable> centroids, + Context context) throws IOException, InterruptedException { + List<Centroid> intermediateCentroids; + // There might be too many intermediate centroids to fit into memory, in which case, we run another pass + // of StreamingKMeans to collapse the clusters further. + if (conf.getBoolean(StreamingKMeansDriver.REDUCE_STREAMING_KMEANS, false)) { + intermediateCentroids = Lists.newArrayList( + new StreamingKMeansThread(Iterables.transform(centroids, new Function<CentroidWritable, Centroid>() { + @Override + public Centroid apply(CentroidWritable input) { + Preconditions.checkNotNull(input); + return input.getCentroid().clone(); + } + }), conf).call()); + } else { + intermediateCentroids = centroidWritablesToList(centroids); + } + + int index = 0; + for (Vector centroid : getBestCentroids(intermediateCentroids, conf)) { + context.write(new IntWritable(index), new CentroidWritable((Centroid) centroid)); + ++index; + } + } + + public static List<Centroid> centroidWritablesToList(Iterable<CentroidWritable> centroids) { + // A new list must be created because Hadoop iterators mutate the contents of the Writable in + // place, without allocating new references when iterating through the centroids Iterable. 
+ return Lists.newArrayList(Iterables.transform(centroids, new Function<CentroidWritable, Centroid>() { + @Override + public Centroid apply(CentroidWritable input) { + Preconditions.checkNotNull(input); + return input.getCentroid().clone(); + } + })); + } + + public static Iterable<Vector> getBestCentroids(List<Centroid> centroids, Configuration conf) { + + if (log.isInfoEnabled()) { + log.info("Number of Centroids: {}", centroids.size()); + } + + int numClusters = conf.getInt(DefaultOptionCreator.NUM_CLUSTERS_OPTION, 1); + int maxNumIterations = conf.getInt(StreamingKMeansDriver.MAX_NUM_ITERATIONS, 10); + float trimFraction = conf.getFloat(StreamingKMeansDriver.TRIM_FRACTION, 0.9f); + boolean kMeansPlusPlusInit = !conf.getBoolean(StreamingKMeansDriver.RANDOM_INIT, false); + boolean correctWeights = !conf.getBoolean(StreamingKMeansDriver.IGNORE_WEIGHTS, false); + float testProbability = conf.getFloat(StreamingKMeansDriver.TEST_PROBABILITY, 0.1f); + int numRuns = conf.getInt(StreamingKMeansDriver.NUM_BALLKMEANS_RUNS, 3); + + BallKMeans ballKMeansCluster = new BallKMeans(StreamingKMeansUtilsMR.searcherFromConfiguration(conf), + numClusters, maxNumIterations, trimFraction, kMeansPlusPlusInit, correctWeights, testProbability, numRuns); + return ballKMeansCluster.cluster(centroids); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansThread.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansThread.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansThread.java new file mode 100644 index 0000000..24cc1db --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansThread.java @@ -0,0 +1,92 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.clustering.streaming.mapreduce; + +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.concurrent.Callable; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.Path; +import org.apache.mahout.clustering.ClusteringUtils; +import org.apache.mahout.clustering.streaming.cluster.StreamingKMeans; +import org.apache.mahout.common.iterator.sequencefile.SequenceFileValueIterable; +import org.apache.mahout.math.Centroid; +import org.apache.mahout.math.VectorWritable; +import org.apache.mahout.math.neighborhood.UpdatableSearcher; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class StreamingKMeansThread implements Callable<Iterable<Centroid>> { + private static final Logger log = LoggerFactory.getLogger(StreamingKMeansThread.class); + + private static final int NUM_ESTIMATE_POINTS = 1000; + + private final Configuration conf; + private final Iterable<Centroid> dataPoints; + + public StreamingKMeansThread(Path input, Configuration conf) { + this(StreamingKMeansUtilsMR.getCentroidsFromVectorWritable( + new SequenceFileValueIterable<VectorWritable>(input, false, conf)), conf); + } + + public StreamingKMeansThread(Iterable<Centroid> dataPoints, Configuration conf) { + this.dataPoints = dataPoints; + this.conf = conf; + } + + @Override + public Iterable<Centroid> call() { + UpdatableSearcher searcher = StreamingKMeansUtilsMR.searcherFromConfiguration(conf); + int numClusters = conf.getInt(StreamingKMeansDriver.ESTIMATED_NUM_MAP_CLUSTERS, 1); + double estimateDistanceCutoff = conf.getFloat(StreamingKMeansDriver.ESTIMATED_DISTANCE_CUTOFF, + StreamingKMeansDriver.INVALID_DISTANCE_CUTOFF); + + Iterator<Centroid> dataPointsIterator = dataPoints.iterator(); + + if (estimateDistanceCutoff == StreamingKMeansDriver.INVALID_DISTANCE_CUTOFF) { + List<Centroid> estimatePoints = new ArrayList<>(NUM_ESTIMATE_POINTS); + while (dataPointsIterator.hasNext() && estimatePoints.size() < NUM_ESTIMATE_POINTS) { + Centroid centroid = dataPointsIterator.next(); + estimatePoints.add(centroid); + } + + if (log.isInfoEnabled()) { + log.info("Estimated Points: {}", estimatePoints.size()); + } + estimateDistanceCutoff = ClusteringUtils.estimateDistanceCutoff(estimatePoints, searcher.getDistanceMeasure()); + } + + StreamingKMeans streamingKMeans = new StreamingKMeans(searcher, numClusters, estimateDistanceCutoff); + + // datapointsIterator could be empty if no estimate distance was initially provided + // hence creating the iterator again here for the clustering + if (!dataPointsIterator.hasNext()) { + dataPointsIterator = dataPoints.iterator(); + } + + while (dataPointsIterator.hasNext()) { + streamingKMeans.cluster(dataPointsIterator.next()); + } + + streamingKMeans.reindexCentroids(); + return streamingKMeans; + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansUtilsMR.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansUtilsMR.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansUtilsMR.java new file mode 100644 index 0000000..f00cf56 --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/StreamingKMeansUtilsMR.java @@ -0,0 +1,154 @@ +/* + * Licensed 
to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.clustering.streaming.mapreduce; + +import java.io.IOException; + +import com.google.common.base.Function; +import com.google.common.base.Preconditions; +import com.google.common.collect.Iterables; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.SequenceFile; +import org.apache.mahout.common.ClassUtils; +import org.apache.mahout.common.commandline.DefaultOptionCreator; +import org.apache.mahout.common.distance.DistanceMeasure; +import org.apache.mahout.math.Centroid; +import org.apache.mahout.math.RandomAccessSparseVector; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.VectorWritable; +import org.apache.mahout.math.neighborhood.BruteSearch; +import org.apache.mahout.math.neighborhood.FastProjectionSearch; +import org.apache.mahout.math.neighborhood.LocalitySensitiveHashSearch; +import org.apache.mahout.math.neighborhood.ProjectionSearch; +import org.apache.mahout.math.neighborhood.UpdatableSearcher; + +public final class StreamingKMeansUtilsMR { + + private StreamingKMeansUtilsMR() { + } + + /** + * Instantiates a searcher from a given configuration. 
+ * @param conf the configuration + * @return the instantiated searcher + * @throws RuntimeException if the distance measure class cannot be instantiated + * @throws IllegalStateException if an unknown searcher class was requested + */ + public static UpdatableSearcher searcherFromConfiguration(Configuration conf) { + DistanceMeasure distanceMeasure; + String distanceMeasureClass = conf.get(DefaultOptionCreator.DISTANCE_MEASURE_OPTION); + try { + distanceMeasure = (DistanceMeasure) Class.forName(distanceMeasureClass).getConstructor().newInstance(); + } catch (Exception e) { + throw new RuntimeException("Failed to instantiate distanceMeasure", e); + } + + int numProjections = conf.getInt(StreamingKMeansDriver.NUM_PROJECTIONS_OPTION, 20); + int searchSize = conf.getInt(StreamingKMeansDriver.SEARCH_SIZE_OPTION, 10); + + String searcherClass = conf.get(StreamingKMeansDriver.SEARCHER_CLASS_OPTION); + + if (searcherClass.equals(BruteSearch.class.getName())) { + return ClassUtils.instantiateAs(searcherClass, UpdatableSearcher.class, + new Class[]{DistanceMeasure.class}, new Object[]{distanceMeasure}); + } else if (searcherClass.equals(FastProjectionSearch.class.getName()) + || searcherClass.equals(ProjectionSearch.class.getName())) { + return ClassUtils.instantiateAs(searcherClass, UpdatableSearcher.class, + new Class[]{DistanceMeasure.class, int.class, int.class}, + new Object[]{distanceMeasure, numProjections, searchSize}); + } else if (searcherClass.equals(LocalitySensitiveHashSearch.class.getName())) { + return ClassUtils.instantiateAs(searcherClass, LocalitySensitiveHashSearch.class, + new Class[]{DistanceMeasure.class, int.class}, + new Object[]{distanceMeasure, searchSize}); + } else { + throw new IllegalStateException("Unknown class instantiation requested"); + } + } + + /** + * Returns an Iterable of centroids from an Iterable of VectorWritables by creating a new Centroid containing + * a RandomAccessSparseVector as a delegate for each VectorWritable. + * @param inputIterable VectorWritable Iterable to get Centroids from + * @return the new Centroids + */ + public static Iterable<Centroid> getCentroidsFromVectorWritable(Iterable<VectorWritable> inputIterable) { + return Iterables.transform(inputIterable, new Function<VectorWritable, Centroid>() { + private int numVectors = 0; + @Override + public Centroid apply(VectorWritable input) { + Preconditions.checkNotNull(input); + return new Centroid(numVectors++, new RandomAccessSparseVector(input.get()), 1); + } + }); + } + + /** + * Returns an Iterable of Centroid from an Iterable of Vector by either casting each Vector to Centroid (if the + * instance extends Centroid) or create a new Centroid based on that Vector. + * The implicit expectation is that the input will not have interleaving types of vectors. Otherwise, the numbering + * of new Centroids will become invalid. + * @param input Iterable of Vectors to cast + * @return the new Centroids + */ + public static Iterable<Centroid> castVectorsToCentroids(Iterable<Vector> input) { + return Iterables.transform(input, new Function<Vector, Centroid>() { + private int numVectors = 0; + @Override + public Centroid apply(Vector input) { + Preconditions.checkNotNull(input); + if (input instanceof Centroid) { + return (Centroid) input; + } else { + return new Centroid(numVectors++, input, 1); + } + } + }); + } + + /** + * Writes centroids to a sequence file. + * @param centroids the centroids to write. + * @param path the path of the output file. 
+ * @param conf the configuration for the HDFS to write the file to. + * @throws java.io.IOException + */ + public static void writeCentroidsToSequenceFile(Iterable<Centroid> centroids, Path path, Configuration conf) + throws IOException { + try (SequenceFile.Writer writer = SequenceFile.createWriter(FileSystem.get(conf), conf, + path, IntWritable.class, CentroidWritable.class)) { + int i = 0; + for (Centroid centroid : centroids) { + writer.append(new IntWritable(i++), new CentroidWritable(centroid)); + } + } + } + + public static void writeVectorsToSequenceFile(Iterable<? extends Vector> datapoints, Path path, Configuration conf) + throws IOException { + try (SequenceFile.Writer writer = SequenceFile.createWriter(FileSystem.get(conf), conf, + path, IntWritable.class, VectorWritable.class)){ + int i = 0; + for (Vector vector : datapoints) { + writer.append(new IntWritable(i++), new VectorWritable(vector)); + } + } + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/streaming/tools/ResplitSequenceFiles.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/streaming/tools/ResplitSequenceFiles.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/streaming/tools/ResplitSequenceFiles.java new file mode 100644 index 0000000..d7ca554 --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/streaming/tools/ResplitSequenceFiles.java @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.clustering.streaming.tools; + +import java.io.IOException; +import java.io.OutputStreamWriter; +import java.io.PrintWriter; +import java.util.Iterator; + +import com.google.common.collect.Iterables; +import org.apache.commons.cli2.CommandLine; +import org.apache.commons.cli2.Group; +import org.apache.commons.cli2.Option; +import org.apache.commons.cli2.builder.ArgumentBuilder; +import org.apache.commons.cli2.builder.DefaultOptionBuilder; +import org.apache.commons.cli2.builder.GroupBuilder; +import org.apache.commons.cli2.commandline.Parser; +import org.apache.commons.cli2.util.HelpFormatter; +import org.apache.commons.io.Charsets; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.SequenceFile; +import org.apache.hadoop.io.Writable; +import org.apache.mahout.common.Pair; +import org.apache.mahout.common.iterator.sequencefile.PathType; +import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable; + +public class ResplitSequenceFiles { + + private String inputFile; + private String outputFileBase; + private int numSplits; + + private Configuration conf; + private FileSystem fs; + + private ResplitSequenceFiles() {} + + private void writeSplit(Iterator<Pair<Writable, Writable>> inputIterator, + int numSplit, int numEntriesPerSplit) throws IOException { + SequenceFile.Writer splitWriter = null; + for (int j = 0; j < numEntriesPerSplit; ++j) { + Pair<Writable, Writable> item = inputIterator.next(); + if (splitWriter == null) { + splitWriter = SequenceFile.createWriter(fs, conf, + new Path(outputFileBase + "-" + numSplit), item.getFirst().getClass(), item.getSecond().getClass()); + } + splitWriter.append(item.getFirst(), item.getSecond()); + } + if (splitWriter != null) { + splitWriter.close(); + } + } + + private void run(PrintWriter printWriter) throws IOException { + conf = new Configuration(); + SequenceFileDirIterable<Writable, Writable> inputIterable = new + SequenceFileDirIterable<>(new Path(inputFile), PathType.LIST, conf); + fs = FileSystem.get(conf); + + int numEntries = Iterables.size(inputIterable); + int numEntriesPerSplit = numEntries / numSplits; + int numEntriesLastSplit = numEntriesPerSplit + numEntries - numEntriesPerSplit * numSplits; + Iterator<Pair<Writable, Writable>> inputIterator = inputIterable.iterator(); + + printWriter.printf("Writing %d splits\n", numSplits); + for (int i = 0; i < numSplits - 1; ++i) { + printWriter.printf("Writing split %d\n", i); + writeSplit(inputIterator, i, numEntriesPerSplit); + } + printWriter.printf("Writing split %d\n", numSplits - 1); + writeSplit(inputIterator, numSplits - 1, numEntriesLastSplit); + } + + private boolean parseArgs(String[] args) { + DefaultOptionBuilder builder = new DefaultOptionBuilder(); + + Option help = builder.withLongName("help").withDescription("print this list").create(); + + ArgumentBuilder argumentBuilder = new ArgumentBuilder(); + Option inputFileOption = builder.withLongName("input") + .withShortName("i") + .withRequired(true) + .withArgument(argumentBuilder.withName("input").withMaximum(1).create()) + .withDescription("what the base folder for sequence files is (they all must have the same key/value type") + .create(); + + Option outputFileOption = builder.withLongName("output") + .withShortName("o") + .withRequired(true) + .withArgument(argumentBuilder.withName("output").withMaximum(1).create()) + .withDescription("the base name of the file split that the 
files will be split it; the i'th split has the " + + "suffix -i") + .create(); + + Option numSplitsOption = builder.withLongName("numSplits") + .withShortName("ns") + .withRequired(true) + .withArgument(argumentBuilder.withName("numSplits").withMaximum(1).create()) + .withDescription("how many splits to use for the given files") + .create(); + + Group normalArgs = new GroupBuilder() + .withOption(help) + .withOption(inputFileOption) + .withOption(outputFileOption) + .withOption(numSplitsOption) + .create(); + + Parser parser = new Parser(); + parser.setHelpOption(help); + parser.setHelpTrigger("--help"); + parser.setGroup(normalArgs); + parser.setHelpFormatter(new HelpFormatter(" ", "", " ", 130)); + CommandLine cmdLine = parser.parseAndHelp(args); + + if (cmdLine == null) { + return false; + } + + inputFile = (String) cmdLine.getValue(inputFileOption); + outputFileBase = (String) cmdLine.getValue(outputFileOption); + numSplits = Integer.parseInt((String) cmdLine.getValue(numSplitsOption)); + return true; + } + + public static void main(String[] args) throws IOException { + ResplitSequenceFiles runner = new ResplitSequenceFiles(); + if (runner.parseArgs(args)) { + runner.run(new PrintWriter(new OutputStreamWriter(System.out, Charsets.UTF_8), true)); + } + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/topdown/PathDirectory.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/topdown/PathDirectory.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/topdown/PathDirectory.java new file mode 100644 index 0000000..11bc34a --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/topdown/PathDirectory.java @@ -0,0 +1,94 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.clustering.topdown; + +import java.io.File; + +import org.apache.hadoop.fs.Path; + +/** + * Contains list of all internal paths used in top down clustering. + */ +public final class PathDirectory { + + public static final String TOP_LEVEL_CLUSTER_DIRECTORY = "topLevelCluster"; + public static final String POST_PROCESS_DIRECTORY = "clusterPostProcessed"; + public static final String CLUSTERED_POINTS_DIRECTORY = "clusteredPoints"; + public static final String BOTTOM_LEVEL_CLUSTER_DIRECTORY = "bottomLevelCluster"; + + private PathDirectory() { + } + + /** + * All output of top level clustering is stored in output directory/topLevelCluster. + * + * @param output + * the output path of clustering. + * @return The top level Cluster Directory. 
+ */ + public static Path getTopLevelClusterPath(Path output) { + return new Path(output + File.separator + TOP_LEVEL_CLUSTER_DIRECTORY); + } + + /** + * The output of top level clusters is post processed and kept in this path. + * + * @param outputPathProvidedByUser + * the output path of clustering. + * @return the path where the output of top level cluster post processor is kept. + */ + public static Path getClusterPostProcessorOutputDirectory(Path outputPathProvidedByUser) { + return new Path(outputPathProvidedByUser + File.separator + POST_PROCESS_DIRECTORY); + } + + /** + * The top level clustered points before post processing is generated here. + * + * @param output + * the output path of clustering. + * @return the clustered points directory + */ + public static Path getClusterOutputClusteredPoints(Path output) { + return new Path(output + File.separator + CLUSTERED_POINTS_DIRECTORY + File.separator, "*"); + } + + /** + * Each cluster produced by top level clustering is processed in output/"bottomLevelCluster"/clusterId. + * + * @param output + * @param clusterId + * @return the bottom level clustering path. + */ + public static Path getBottomLevelClusterPath(Path output, String clusterId) { + return new Path(output + File.separator + BOTTOM_LEVEL_CLUSTER_DIRECTORY + File.separator + clusterId); + } + + /** + * Each clusters path name is its clusterId. The vectors reside in separate files inside it. + * + * @param clusterPostProcessorOutput + * the path of cluster post processor output. + * @param clusterId + * the id of the cluster. + * @return the cluster path for cluster id. + */ + public static Path getClusterPathForClusterId(Path clusterPostProcessorOutput, String clusterId) { + return new Path(clusterPostProcessorOutput + File.separator + clusterId); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/topdown/postprocessor/ClusterCountReader.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/topdown/postprocessor/ClusterCountReader.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/topdown/postprocessor/ClusterCountReader.java new file mode 100644 index 0000000..d0563fd --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/topdown/postprocessor/ClusterCountReader.java @@ -0,0 +1,103 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.apache.mahout.clustering.topdown.postprocessor;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileStatus;
+import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.clustering.iterator.ClusterWritable;
+import org.apache.mahout.common.iterator.sequencefile.PathFilters;
+import org.apache.mahout.common.iterator.sequencefile.PathType;
+import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirValueIterator;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.Map;
+
+/**
+ * Reads the number of clusters produced by the clustering algorithm.
+ */
+public final class ClusterCountReader {
+
+  private ClusterCountReader() {
+  }
+
+  /**
+   * Reads the number of clusters present by reading the clusters-*-final file.
+   *
+   * @param clusterOutputPath The output path provided to the clustering algorithm.
+   * @param conf              The Hadoop configuration.
+   * @return the number of final clusters.
+   */
+  public static int getNumberOfClusters(Path clusterOutputPath, Configuration conf) throws IOException {
+    FileSystem fileSystem = clusterOutputPath.getFileSystem(conf);
+    FileStatus[] clusterFiles = fileSystem.listStatus(clusterOutputPath, PathFilters.finalPartFilter());
+    int numberOfClusters = 0;
+    Iterator<?> it = new SequenceFileDirValueIterator<>(clusterFiles[0].getPath(),
+        PathType.LIST,
+        PathFilters.partFilter(),
+        null,
+        true,
+        conf);
+    while (it.hasNext()) {
+      it.next();
+      numberOfClusters++;
+    }
+    return numberOfClusters;
+  }
+
+  /**
+   * Generates a mapping of all cluster ids by reading the clusters-*-final file.
+   *
+   * @param clusterOutputPath The output path provided to the clustering algorithm.
+   * @param conf              The Hadoop configuration.
+   * @param keyIsClusterId    If true, the returned map is keyed by cluster id with a 0-based index as value;
+   *                          if false, it is keyed by the index with the cluster id as value.
+   * @return A map of the final cluster ids.
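+ *         For illustration, assuming two final clusters with (hypothetical) ids 17 and 42:
+ *         <pre>{@code
+ * // keyIsClusterId == true  -> {17=0, 42=1}  (cluster id mapped to a 0..k-1 index)
+ * // keyIsClusterId == false -> {0=17, 1=42}  (index mapped back to the cluster id)
+ * Map<Integer, Integer> clusterIdToIndex = ClusterCountReader.getClusterIDs(clusterOutputPath, conf, true);
+ * }</pre>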
+ */ + public static Map<Integer, Integer> getClusterIDs(Path clusterOutputPath, Configuration conf, boolean keyIsClusterId) + throws IOException { + Map<Integer, Integer> clusterIds = new HashMap<>(); + FileSystem fileSystem = clusterOutputPath.getFileSystem(conf); + FileStatus[] clusterFiles = fileSystem.listStatus(clusterOutputPath, PathFilters.finalPartFilter()); + //System.out.println("LOOK HERE: " + clusterOutputPath); + Iterator<ClusterWritable> it = new SequenceFileDirValueIterator<>(clusterFiles[0].getPath(), + PathType.LIST, + PathFilters.partFilter(), + null, + true, + conf); + int i = 0; + while (it.hasNext()) { + Integer key; + Integer value; + if (keyIsClusterId) { // key is the cluster id, value is i, the index we will use + key = it.next().getValue().getId(); + value = i; + } else { + key = i; + value = it.next().getValue().getId(); + } + clusterIds.put(key, value); + i++; + } + return clusterIds; + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/topdown/postprocessor/ClusterOutputPostProcessor.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/topdown/postprocessor/ClusterOutputPostProcessor.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/topdown/postprocessor/ClusterOutputPostProcessor.java new file mode 100644 index 0000000..ded76ad --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/topdown/postprocessor/ClusterOutputPostProcessor.java @@ -0,0 +1,139 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.clustering.topdown.postprocessor; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.LongWritable; +import org.apache.hadoop.io.SequenceFile; +import org.apache.hadoop.io.SequenceFile.Writer; +import org.apache.hadoop.io.Writable; +import org.apache.mahout.clustering.classify.WeightedVectorWritable; +import org.apache.mahout.clustering.topdown.PathDirectory; +import org.apache.mahout.common.IOUtils; +import org.apache.mahout.common.Pair; +import org.apache.mahout.common.iterator.sequencefile.PathFilters; +import org.apache.mahout.common.iterator.sequencefile.PathType; +import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirIterable; +import org.apache.mahout.math.VectorWritable; + +/** + * This class reads the output of any clustering algorithm, and, creates separate directories for different + * clusters. Each cluster directory's name is its clusterId. 
Each and every point is written in the cluster + * directory associated with that point. + * <p/> + * This class incorporates a sequential algorithm and is appropriate for use for data which has been clustered + * sequentially. + * <p/> + * The sequential and non sequential version, both are being used from {@link ClusterOutputPostProcessorDriver}. + */ +public final class ClusterOutputPostProcessor { + + private Path clusteredPoints; + private final FileSystem fileSystem; + private final Configuration conf; + private final Path clusterPostProcessorOutput; + private final Map<String, Path> postProcessedClusterDirectories = new HashMap<>(); + private long uniqueVectorId = 0L; + private final Map<String, SequenceFile.Writer> writersForClusters; + + public ClusterOutputPostProcessor(Path clusterOutputToBeProcessed, + Path output, + Configuration hadoopConfiguration) throws IOException { + this.clusterPostProcessorOutput = output; + this.clusteredPoints = PathDirectory.getClusterOutputClusteredPoints(clusterOutputToBeProcessed); + this.conf = hadoopConfiguration; + this.writersForClusters = new HashMap<>(); + fileSystem = clusteredPoints.getFileSystem(conf); + } + + /** + * This method takes the clustered points output by the clustering algorithms as input and writes them into + * their respective clusters. + */ + public void process() throws IOException { + createPostProcessDirectory(); + for (Pair<?, WeightedVectorWritable> record + : new SequenceFileDirIterable<Writable, WeightedVectorWritable>(clusteredPoints, PathType.GLOB, PathFilters.partFilter(), + null, false, conf)) { + String clusterId = record.getFirst().toString().trim(); + putVectorInRespectiveCluster(clusterId, record.getSecond()); + } + IOUtils.close(writersForClusters.values()); + writersForClusters.clear(); + } + + /** + * Creates the directory to put post processed clusters. + */ + private void createPostProcessDirectory() throws IOException { + if (!fileSystem.exists(clusterPostProcessorOutput) + && !fileSystem.mkdirs(clusterPostProcessorOutput)) { + throw new IOException("Error creating cluster post processor directory"); + } + } + + /** + * Finds out the cluster directory of the vector and writes it into the specified cluster. + */ + private void putVectorInRespectiveCluster(String clusterId, WeightedVectorWritable point) throws IOException { + Writer writer = findWriterForVector(clusterId); + postProcessedClusterDirectories.put(clusterId, + PathDirectory.getClusterPathForClusterId(clusterPostProcessorOutput, clusterId)); + writeVectorToCluster(writer, point); + } + + /** + * Finds out the path in cluster where the point is supposed to be written. + */ + private Writer findWriterForVector(String clusterId) throws IOException { + Path clusterDirectory = PathDirectory.getClusterPathForClusterId(clusterPostProcessorOutput, clusterId); + Writer writer = writersForClusters.get(clusterId); + if (writer == null) { + Path pathToWrite = new Path(clusterDirectory, new Path("part-m-0")); + writer = new Writer(fileSystem, conf, pathToWrite, LongWritable.class, VectorWritable.class); + writersForClusters.put(clusterId, writer); + } + return writer; + } + + /** + * Writes vector to the cluster directory. + */ + private void writeVectorToCluster(Writer writer, WeightedVectorWritable point) throws IOException { + writer.append(new LongWritable(uniqueVectorId++), new VectorWritable(point.getVector())); + writer.sync(); + } + + /** + * @return the set of all post processed cluster paths. 
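+ *         A minimal usage sketch (paths are hypothetical):
+ *         <pre>{@code
+ * ClusterOutputPostProcessor postProcessor =
+ *     new ClusterOutputPostProcessor(clusterOutput, postProcessedOutput, new Configuration());
+ * postProcessor.process();
+ * Map<String, Path> clusterDirs = postProcessor.getPostProcessedClusterDirectories();
+ * }</pre>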
+ */ + public Map<String, Path> getPostProcessedClusterDirectories() { + return postProcessedClusterDirectories; + } + + public void setClusteredPoints(Path clusteredPoints) { + this.clusteredPoints = clusteredPoints; + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/topdown/postprocessor/ClusterOutputPostProcessorDriver.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/topdown/postprocessor/ClusterOutputPostProcessorDriver.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/topdown/postprocessor/ClusterOutputPostProcessorDriver.java new file mode 100644 index 0000000..82a3071 --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/topdown/postprocessor/ClusterOutputPostProcessorDriver.java @@ -0,0 +1,182 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.clustering.topdown.postprocessor; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FileStatus; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.Writable; +import org.apache.hadoop.mapreduce.Job; +import org.apache.hadoop.mapreduce.lib.input.FileInputFormat; +import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat; +import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat; +import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat; +import org.apache.hadoop.util.ToolRunner; +import org.apache.mahout.common.AbstractJob; +import org.apache.mahout.common.HadoopUtil; +import org.apache.mahout.common.commandline.DefaultOptionCreator; +import org.apache.mahout.common.iterator.sequencefile.PathFilters; +import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterator; +import org.apache.mahout.math.VectorWritable; + +import java.io.IOException; + +/** + * Post processes the output of clustering algorithms and groups them into respective clusters. Ideal to be + * used for top down clustering. It can also be used if the clustering output needs to be grouped into their + * respective clusters. + */ +public final class ClusterOutputPostProcessorDriver extends AbstractJob { + + /** + * CLI to run clustering post processor. The input to post processor is the ouput path specified to the + * clustering. 
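+   * <p>A minimal invocation sketch (paths are hypothetical; option names follow the usual
+   * {@code DefaultOptionCreator} long forms used by this driver):
+   * <pre>{@code
+   * ToolRunner.run(new Configuration(), new ClusterOutputPostProcessorDriver(), new String[] {
+   *     "--input", "/clustering/output",
+   *     "--output", "/clustering/postProcessed",
+   *     "--method", "sequential"
+   * });
+   * }</pre>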
+   */
+  @Override
+  public int run(String[] args) throws Exception {
+    addInputOption();
+    addOutputOption();
+    addOption(DefaultOptionCreator.methodOption().create());
+    addOption(DefaultOptionCreator.overwriteOption().create());
+
+    if (parseArguments(args) == null) {
+      return -1;
+    }
+    Path input = getInputPath();
+    Path output = getOutputPath();
+
+    if (hasOption(DefaultOptionCreator.OVERWRITE_OPTION)) {
+      HadoopUtil.delete(getConf(), output);
+    }
+    boolean runSequential = getOption(DefaultOptionCreator.METHOD_OPTION).equalsIgnoreCase(
+        DefaultOptionCreator.SEQUENTIAL_METHOD);
+    run(input, output, runSequential);
+    return 0;
+  }
+
+  public static void main(String[] args) throws Exception {
+    ToolRunner.run(new Configuration(), new ClusterOutputPostProcessorDriver(), args);
+  }
+
+  /**
+   * Post processes the output of clustering algorithms and groups the vectors into their respective clusters.
+   * Each cluster's vectors are written into a directory named after its clusterId.
+   *
+   * @param input         The output path provided to the clustering algorithm, which will be post processed. Hint:
+   *                      the path of the directory containing clusters-*-final and clusteredPoints.
+   * @param output        The path where the post processed data will be stored.
+   * @param runSequential If true, post processes sequentially; otherwise uses MapReduce. Hint: use the same mode
+   *                      that was used for the clustering itself.
+   */
+  public static void run(Path input, Path output, boolean runSequential) throws IOException,
+      InterruptedException,
+      ClassNotFoundException {
+    if (runSequential) {
+      postProcessSeq(input, output);
+    } else {
+      Configuration conf = new Configuration();
+      postProcessMR(conf, input, output);
+      movePartFilesToRespectiveDirectories(conf, output);
+    }
+  }
+
+  /**
+   * Processes sequentially. Reads the vectors one by one and puts each into the directory named after its
+   * clusterId.
+   *
+   * @param input  The output path provided to the clustering algorithm, which will be post processed. Hint:
+   *               the path of the directory containing clusters-*-final and clusteredPoints.
+   * @param output The path where the post processed data will be stored.
+   */
+  private static void postProcessSeq(Path input, Path output) throws IOException {
+    ClusterOutputPostProcessor clusterOutputPostProcessor = new ClusterOutputPostProcessor(input, output,
+        new Configuration());
+    clusterOutputPostProcessor.process();
+  }
+
+  /**
+   * Processes as a MapReduce job. The number of reduce tasks is set to the number of clusters present in the
+   * output so that each cluster's vectors are written into their own part file.
+   *
+   * @param conf   The Hadoop configuration.
+   * @param input  The output path provided to the clustering algorithm, which will be post processed. Hint:
+   *               the path of the directory containing clusters-*-final and clusteredPoints.
+   * @param output The path where the post processed data will be stored.
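+   *               <p>Illustrative layout (cluster ids are hypothetical): after this job and the subsequent
+   *               move step, each cluster's vectors end up in a part file under a directory named after the
+   *               original cluster id, e.g.
+   *               <pre>{@code
+   * <output>/17/part-r-00000
+   * <output>/42/part-r-00003
+   * }</pre>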
+ */ + private static void postProcessMR(Configuration conf, Path input, Path output) throws IOException, + InterruptedException, + ClassNotFoundException { + System.out.println("WARNING: If you are running in Hadoop local mode, please use the --sequential option, " + + "as the MapReduce option will not work properly"); + int numberOfClusters = ClusterCountReader.getNumberOfClusters(input, conf); + conf.set("clusterOutputPath", input.toString()); + Job job = new Job(conf, "ClusterOutputPostProcessor Driver running over input: " + input); + job.setInputFormatClass(SequenceFileInputFormat.class); + job.setOutputFormatClass(SequenceFileOutputFormat.class); + job.setMapperClass(ClusterOutputPostProcessorMapper.class); + job.setMapOutputKeyClass(IntWritable.class); + job.setMapOutputValueClass(VectorWritable.class); + job.setReducerClass(ClusterOutputPostProcessorReducer.class); + job.setOutputKeyClass(IntWritable.class); + job.setOutputValueClass(VectorWritable.class); + job.setNumReduceTasks(numberOfClusters); + job.setJarByClass(ClusterOutputPostProcessorDriver.class); + + FileInputFormat.addInputPath(job, new Path(input, new Path("clusteredPoints"))); + FileOutputFormat.setOutputPath(job, output); + if (!job.waitForCompletion(true)) { + throw new InterruptedException("ClusterOutputPostProcessor Job failed processing " + input); + } + } + + /** + * The mapreduce version of the post processor writes different clusters into different part files. This + * method reads the part files and moves them into directories named after their clusterIds. + * + * @param conf The hadoop configuration. + * @param output The post processed data would be stored at this path. + */ + private static void movePartFilesToRespectiveDirectories(Configuration conf, Path output) throws IOException { + FileSystem fileSystem = output.getFileSystem(conf); + for (FileStatus fileStatus : fileSystem.listStatus(output, PathFilters.partFilter())) { + SequenceFileIterator<Writable, Writable> it = + new SequenceFileIterator<>(fileStatus.getPath(), true, conf); + if (it.hasNext()) { + renameFile(it.next().getFirst(), fileStatus, conf); + } + it.close(); + } + } + + /** + * Using @FileSystem rename method to move the file. + */ + private static void renameFile(Writable key, FileStatus fileStatus, Configuration conf) throws IOException { + Path path = fileStatus.getPath(); + FileSystem fileSystem = path.getFileSystem(conf); + Path subDir = new Path(key.toString()); + Path renameTo = new Path(path.getParent(), subDir); + fileSystem.mkdirs(renameTo); + fileSystem.rename(path, renameTo); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/topdown/postprocessor/ClusterOutputPostProcessorMapper.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/topdown/postprocessor/ClusterOutputPostProcessorMapper.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/topdown/postprocessor/ClusterOutputPostProcessorMapper.java new file mode 100644 index 0000000..6834362 --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/topdown/postprocessor/ClusterOutputPostProcessorMapper.java @@ -0,0 +1,58 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.clustering.topdown.postprocessor; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.mapreduce.Mapper; +import org.apache.mahout.clustering.classify.WeightedVectorWritable; +import org.apache.mahout.math.VectorWritable; + +import java.io.IOException; +import java.util.Map; + +/** + * Mapper for post processing cluster output. + */ +public class ClusterOutputPostProcessorMapper extends + Mapper<IntWritable, WeightedVectorWritable, IntWritable, VectorWritable> { + + private Map<Integer, Integer> newClusterMappings; + private VectorWritable outputVector; + + //read the current cluster ids, and populate the cluster mapping hash table + @Override + public void setup(Context context) throws IOException { + Configuration conf = context.getConfiguration(); + //this give the clusters-x-final directory where the cluster ids can be read + Path clusterOutputPath = new Path(conf.get("clusterOutputPath")); + //we want the key to be the cluster id, the value to be the index + newClusterMappings = ClusterCountReader.getClusterIDs(clusterOutputPath, conf, true); + outputVector = new VectorWritable(); + } + + @Override + public void map(IntWritable key, WeightedVectorWritable val, Context context) + throws IOException, InterruptedException { + // by pivoting on the cluster mapping value, we can make sure that each unique cluster goes to it's own reducer, + // since they are numbered from 0 to k-1, where k is the number of clusters + outputVector.set(val.getVector()); + context.write(new IntWritable(newClusterMappings.get(key.get())), outputVector); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/410ed16a/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/topdown/postprocessor/ClusterOutputPostProcessorReducer.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/topdown/postprocessor/ClusterOutputPostProcessorReducer.java b/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/topdown/postprocessor/ClusterOutputPostProcessorReducer.java new file mode 100644 index 0000000..58dada4 --- /dev/null +++ b/community/mahout-mr/mr/src/main/java/org/apache/mahout/clustering/topdown/postprocessor/ClusterOutputPostProcessorReducer.java @@ -0,0 +1,62 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.clustering.topdown.postprocessor;
+
+import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.Path;
+import org.apache.hadoop.io.IntWritable;
+import org.apache.hadoop.mapreduce.Reducer;
+import org.apache.mahout.math.VectorWritable;
+
+import java.io.IOException;
+import java.util.Map;
+
+/**
+ * Reducer for post processing cluster output.
+ */
+public class ClusterOutputPostProcessorReducer
+    extends Reducer<IntWritable, VectorWritable, IntWritable, VectorWritable> {
+
+  private Map<Integer, Integer> reverseClusterMappings;
+
+  // read the current cluster ids and populate the index-to-cluster-id mapping
+  @Override
+  public void setup(Context context) throws IOException {
+    Configuration conf = context.getConfiguration();
+    Path clusterOutputPath = new Path(conf.get("clusterOutputPath"));
+    // we want the key to be the index and the value to be the cluster id
+    reverseClusterMappings = ClusterCountReader.getClusterIDs(clusterOutputPath, conf, false);
+  }
+
+  /**
+   * The key is the remapped cluster id and the values contain the vectors in that cluster.
+   */
+  @Override
+  protected void reduce(IntWritable key, Iterable<VectorWritable> values, Context context) throws IOException,
+      InterruptedException {
+    // remap the cluster back to its original id and then output the vectors with their correct cluster id
+    IntWritable outKey = new IntWritable(reverseClusterMappings.get(key.get()));
+    for (VectorWritable value : values) {
+      context.write(outKey, value);
+    }
+  }
+
+}
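
A short end-to-end sketch of how these pieces fit together (illustrative only; the paths below are hypothetical, and exception handling is omitted):

    Configuration conf = new Configuration();
    Path clusterOutput = new Path("/clustering/output");        // contains clusters-*-final and clusteredPoints
    Path postProcessed = new Path("/clustering/postProcessed");
    // number of final clusters; used as the reducer count on the MapReduce path
    int k = ClusterCountReader.getNumberOfClusters(clusterOutput, conf);
    // groups the clustered points into one directory per cluster id under postProcessed
    ClusterOutputPostProcessorDriver.run(clusterOutput, postProcessed, true);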
