http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/clustering/spectral/VectorMatrixMultiplicationJob.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/clustering/spectral/VectorMatrixMultiplicationJob.java b/community/mahout-mr/src/main/java/org/apache/mahout/clustering/spectral/VectorMatrixMultiplicationJob.java new file mode 100644 index 0000000..c42ab70 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/clustering/spectral/VectorMatrixMultiplicationJob.java @@ -0,0 +1,139 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.clustering.spectral; + +import java.io.IOException; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.mapreduce.Job; +import org.apache.hadoop.mapreduce.Mapper; +import org.apache.hadoop.mapreduce.lib.input.FileInputFormat; +import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat; +import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat; +import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat; +import org.apache.mahout.math.DenseVector; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.VectorWritable; +import org.apache.mahout.math.function.Functions; +import org.apache.mahout.math.hadoop.DistributedRowMatrix; + +/** + * <p>This class handles the three-way multiplication of the digonal matrix + * and the Markov transition matrix inherent in the Eigencuts algorithm. + * The equation takes the form:</p> + * + * {@code W = D^(1/2) * M * D^(1/2)} + * + * <p>Since the diagonal matrix D has only n non-zero elements, it is represented + * as a dense vector in this job, rather than a full n-by-n matrix. This job + * performs the multiplications and returns the new DRM. + */ +public final class VectorMatrixMultiplicationJob { + + private VectorMatrixMultiplicationJob() { + } + + /** + * Invokes the job. 
+ * @param markovPath Path to the markov DRM's sequence files + */ + public static DistributedRowMatrix runJob(Path markovPath, Vector diag, Path outputPath) + throws IOException, ClassNotFoundException, InterruptedException { + + return runJob(markovPath, diag, outputPath, new Path(outputPath, "tmp")); + } + + public static DistributedRowMatrix runJob(Path markovPath, Vector diag, Path outputPath, Path tmpPath) + throws IOException, ClassNotFoundException, InterruptedException { + + // set up the serialization of the diagonal vector + Configuration conf = new Configuration(); + FileSystem fs = FileSystem.get(markovPath.toUri(), conf); + markovPath = fs.makeQualified(markovPath); + outputPath = fs.makeQualified(outputPath); + Path vectorOutputPath = new Path(outputPath.getParent(), "vector"); + VectorCache.save(new IntWritable(Keys.DIAGONAL_CACHE_INDEX), diag, vectorOutputPath, conf); + + // set up the job itself + Job job = new Job(conf, "VectorMatrixMultiplication"); + job.setInputFormatClass(SequenceFileInputFormat.class); + job.setOutputKeyClass(IntWritable.class); + job.setOutputValueClass(VectorWritable.class); + job.setOutputFormatClass(SequenceFileOutputFormat.class); + job.setMapperClass(VectorMatrixMultiplicationMapper.class); + job.setNumReduceTasks(0); + + FileInputFormat.addInputPath(job, markovPath); + FileOutputFormat.setOutputPath(job, outputPath); + + job.setJarByClass(VectorMatrixMultiplicationJob.class); + + boolean succeeded = job.waitForCompletion(true); + if (!succeeded) { + throw new IllegalStateException("Job failed!"); + } + + // build the resulting DRM from the results + return new DistributedRowMatrix(outputPath, tmpPath, + diag.size(), diag.size()); + } + + public static class VectorMatrixMultiplicationMapper + extends Mapper<IntWritable, VectorWritable, IntWritable, VectorWritable> { + + private Vector diagonal; + + @Override + protected void setup(Context context) throws IOException, InterruptedException { + // read in the diagonal vector from the distributed cache + super.setup(context); + Configuration config = context.getConfiguration(); + diagonal = VectorCache.load(config); + if (diagonal == null) { + throw new IOException("No vector loaded from cache!"); + } + if (!(diagonal instanceof DenseVector)) { + diagonal = new DenseVector(diagonal); + } + } + + @Override + protected void map(IntWritable key, VectorWritable row, Context ctx) + throws IOException, InterruptedException { + + for (Vector.Element e : row.get().all()) { + double dii = Functions.SQRT.apply(diagonal.get(key.get())); + double djj = Functions.SQRT.apply(diagonal.get(e.index())); + double mij = e.get(); + e.set(dii * mij * djj); + } + ctx.write(key, row); + } + + /** + * Performs the setup of the Mapper. Used by unit tests. + * @param diag + */ + void setup(Vector diag) { + this.diagonal = diag; + } + } +}
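To see the W = D^(1/2) * M * D^(1/2) scaling above without the MapReduce scaffolding, here is a minimal, self-contained sketch of what VectorMatrixMultiplicationMapper.map does to a single row. It only uses mahout-math types the job already imports (DenseVector, Vector, Functions); the class name RowScalingSketch and the sample numbers are made up for illustration and are not part of this commit.

import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.function.Functions;

public class RowScalingSketch {

  /** Scales row i of M in place so that w_ij = sqrt(d_ii) * m_ij * sqrt(d_jj). */
  static void scaleRow(int i, Vector row, Vector diag) {
    double dii = Functions.SQRT.apply(diag.get(i));
    for (Vector.Element e : row.all()) {
      double djj = Functions.SQRT.apply(diag.get(e.index()));
      e.set(dii * e.get() * djj);
    }
  }

  public static void main(String[] args) {
    // Hypothetical Markov row 0 and diagonal (degree) vector for a 3x3 matrix.
    Vector row0 = new DenseVector(new double[] {0.0, 0.5, 0.5});
    Vector diag = new DenseVector(new double[] {2.0, 1.0, 4.0});
    scaleRow(0, row0, diag);
    System.out.println(row0); // each m_0j has been multiplied by sqrt(d_00) * sqrt(d_jj)
  }
}

The only difference from the mapper is that dii is hoisted out of the loop; the job itself ships the diagonal vector to every mapper through VectorCache rather than passing it as an argument.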
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/clustering/spectral/VertexWritable.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/clustering/spectral/VertexWritable.java b/community/mahout-mr/src/main/java/org/apache/mahout/clustering/spectral/VertexWritable.java new file mode 100644 index 0000000..0d70cac --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/clustering/spectral/VertexWritable.java @@ -0,0 +1,101 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.clustering.spectral; + +import java.io.DataInput; +import java.io.DataOutput; +import java.io.IOException; + +import org.apache.hadoop.io.Writable; + +/** + * Represents a vertex within the affinity graph for Eigencuts. + */ +public class VertexWritable implements Writable { + + /** the row */ + private int i; + + /** the column */ + private int j; + + /** the value at this vertex */ + private double value; + + /** an extra type delimeter, can probably be null */ + private String type; + + public VertexWritable() { + } + + public VertexWritable(int i, int j, double v, String t) { + this.i = i; + this.j = j; + this.value = v; + this.type = t; + } + + public int getRow() { + return i; + } + + public void setRow(int i) { + this.i = i; + } + + public int getCol() { + return j; + } + + public void setCol(int j) { + this.j = j; + } + + public double getValue() { + return value; + } + + public void setValue(double v) { + this.value = v; + } + + public String getType() { + return type; + } + + public void setType(String t) { + this.type = t; + } + + @Override + public void readFields(DataInput arg0) throws IOException { + this.i = arg0.readInt(); + this.j = arg0.readInt(); + this.value = arg0.readDouble(); + this.type = arg0.readUTF(); + } + + @Override + public void write(DataOutput arg0) throws IOException { + arg0.writeInt(i); + arg0.writeInt(j); + arg0.writeDouble(value); + arg0.writeUTF(type); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/clustering/spectral/kmeans/EigenSeedGenerator.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/clustering/spectral/kmeans/EigenSeedGenerator.java b/community/mahout-mr/src/main/java/org/apache/mahout/clustering/spectral/kmeans/EigenSeedGenerator.java new file mode 100644 index 0000000..3ce94dc --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/clustering/spectral/kmeans/EigenSeedGenerator.java @@ -0,0 +1,120 @@ +/** + * Licensed to the Apache Software Foundation 
(ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.clustering.spectral.kmeans; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FileStatus; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.SequenceFile; +import org.apache.hadoop.io.Text; +import org.apache.hadoop.io.Writable; +import org.apache.mahout.clustering.iterator.ClusterWritable; +import org.apache.mahout.clustering.kmeans.Kluster; +import org.apache.mahout.common.HadoopUtil; +import org.apache.mahout.common.Pair; +import org.apache.mahout.common.distance.DistanceMeasure; +import org.apache.mahout.common.iterator.sequencefile.PathFilters; +import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.VectorWritable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Given an Input Path containing a {@link org.apache.hadoop.io.SequenceFile}, select k vectors and write them to the + * output file as a {@link org.apache.mahout.clustering.kmeans.Kluster} representing the initial centroid to use. 
The + * selection criterion is the rows with max value in that respective column + */ +public final class EigenSeedGenerator { + + private static final Logger log = LoggerFactory.getLogger(EigenSeedGenerator.class); + + public static final String K = "k"; + + private EigenSeedGenerator() {} + + public static Path buildFromEigens(Configuration conf, Path input, Path output, int k, DistanceMeasure measure) + throws IOException { + // delete the output directory + FileSystem fs = FileSystem.get(output.toUri(), conf); + HadoopUtil.delete(conf, output); + Path outFile = new Path(output, "part-eigenSeed"); + boolean newFile = fs.createNewFile(outFile); + if (newFile) { + Path inputPathPattern; + + if (fs.getFileStatus(input).isDir()) { + inputPathPattern = new Path(input, "*"); + } else { + inputPathPattern = input; + } + + FileStatus[] inputFiles = fs.globStatus(inputPathPattern, PathFilters.logsCRCFilter()); + Map<Integer,Double> maxEigens = new HashMap<>(k); // store + // max + // value + // of + // each + // column + Map<Integer,Text> chosenTexts = new HashMap<>(k); + Map<Integer,ClusterWritable> chosenClusters = new HashMap<>(k); + + for (FileStatus fileStatus : inputFiles) { + if (!fileStatus.isDir()) { + for (Pair<Writable,VectorWritable> record : new SequenceFileIterable<Writable,VectorWritable>( + fileStatus.getPath(), true, conf)) { + Writable key = record.getFirst(); + VectorWritable value = record.getSecond(); + + for (Vector.Element e : value.get().nonZeroes()) { + int index = e.index(); + double v = Math.abs(e.get()); + + if (!maxEigens.containsKey(index) || v > maxEigens.get(index)) { + maxEigens.put(index, v); + Text newText = new Text(key.toString()); + chosenTexts.put(index, newText); + Kluster newCluster = new Kluster(value.get(), index, measure); + newCluster.observe(value.get(), 1); + ClusterWritable clusterWritable = new ClusterWritable(); + clusterWritable.setValue(newCluster); + chosenClusters.put(index, clusterWritable); + } + } + } + } + } + + try (SequenceFile.Writer writer = + SequenceFile.createWriter(fs, conf, outFile, Text.class, ClusterWritable.class)){ + for (Integer key : maxEigens.keySet()) { + writer.append(chosenTexts.get(key), chosenClusters.get(key)); + } + log.info("EigenSeedGenerator:: Wrote {} Klusters to {}", chosenTexts.size(), outFile); + } + } + + return outFile; + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/clustering/spectral/kmeans/SpectralKMeansDriver.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/clustering/spectral/kmeans/SpectralKMeansDriver.java b/community/mahout-mr/src/main/java/org/apache/mahout/clustering/spectral/kmeans/SpectralKMeansDriver.java new file mode 100644 index 0000000..427de91 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/clustering/spectral/kmeans/SpectralKMeansDriver.java @@ -0,0 +1,243 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * <p/> + * http://www.apache.org/licenses/LICENSE-2.0 + * <p/> + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.clustering.spectral.kmeans; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.SequenceFile; +import org.apache.hadoop.io.Text; +import org.apache.hadoop.util.ToolRunner; +import org.apache.mahout.clustering.Cluster; +import org.apache.mahout.clustering.classify.WeightedVectorWritable; +import org.apache.mahout.clustering.kmeans.KMeansDriver; +import org.apache.mahout.clustering.spectral.AffinityMatrixInputJob; +import org.apache.mahout.clustering.spectral.MatrixDiagonalizeJob; +import org.apache.mahout.clustering.spectral.UnitVectorizerJob; +import org.apache.mahout.clustering.spectral.VectorMatrixMultiplicationJob; +import org.apache.mahout.common.AbstractJob; +import org.apache.mahout.common.ClassUtils; +import org.apache.mahout.common.HadoopUtil; +import org.apache.mahout.common.Pair; +import org.apache.mahout.common.commandline.DefaultOptionCreator; +import org.apache.mahout.common.distance.DistanceMeasure; +import org.apache.mahout.common.iterator.sequencefile.SequenceFileIterable; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.hadoop.DistributedRowMatrix; +import org.apache.mahout.math.hadoop.stochasticsvd.SSVDSolver; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Performs spectral k-means clustering on the top k eigenvectors of the input affinity matrix. + */ +public class SpectralKMeansDriver extends AbstractJob { + private static final Logger log = LoggerFactory.getLogger(SpectralKMeansDriver.class); + + public static final int REDUCERS = 10; + public static final int BLOCKHEIGHT = 30000; + public static final int OVERSAMPLING = 15; + public static final int POWERITERS = 0; + + public static void main(String[] args) throws Exception { + ToolRunner.run(new SpectralKMeansDriver(), args); + } + + @Override + public int run(String[] arg0) throws Exception { + + Configuration conf = getConf(); + addInputOption(); + addOutputOption(); + addOption("dimensions", "d", "Square dimensions of affinity matrix", true); + addOption("clusters", "k", "Number of clusters and top eigenvectors", true); + addOption(DefaultOptionCreator.distanceMeasureOption().create()); + addOption(DefaultOptionCreator.convergenceOption().create()); + addOption(DefaultOptionCreator.maxIterationsOption().create()); + addOption(DefaultOptionCreator.overwriteOption().create()); + addFlag("usessvd", "ssvd", "Uses SSVD as the eigensolver. 
Default is the Lanczos solver."); + addOption("reduceTasks", "t", "Number of reducers for SSVD", String.valueOf(REDUCERS)); + addOption("outerProdBlockHeight", "oh", "Block height of outer products for SSVD", String.valueOf(BLOCKHEIGHT)); + addOption("oversampling", "p", "Oversampling parameter for SSVD", String.valueOf(OVERSAMPLING)); + addOption("powerIter", "q", "Additional power iterations for SSVD", String.valueOf(POWERITERS)); + + Map<String, List<String>> parsedArgs = parseArguments(arg0); + if (parsedArgs == null) { + return 0; + } + + Path input = getInputPath(); + Path output = getOutputPath(); + if (hasOption(DefaultOptionCreator.OVERWRITE_OPTION)) { + HadoopUtil.delete(conf, getTempPath()); + HadoopUtil.delete(conf, getOutputPath()); + } + int numDims = Integer.parseInt(getOption("dimensions")); + int clusters = Integer.parseInt(getOption("clusters")); + String measureClass = getOption(DefaultOptionCreator.DISTANCE_MEASURE_OPTION); + DistanceMeasure measure = ClassUtils.instantiateAs(measureClass, DistanceMeasure.class); + double convergenceDelta = Double.parseDouble(getOption(DefaultOptionCreator.CONVERGENCE_DELTA_OPTION)); + int maxIterations = Integer.parseInt(getOption(DefaultOptionCreator.MAX_ITERATIONS_OPTION)); + + Path tempdir = new Path(getOption("tempDir")); + int reducers = Integer.parseInt(getOption("reduceTasks")); + int blockheight = Integer.parseInt(getOption("outerProdBlockHeight")); + int oversampling = Integer.parseInt(getOption("oversampling")); + int poweriters = Integer.parseInt(getOption("powerIter")); + run(conf, input, output, numDims, clusters, measure, convergenceDelta, maxIterations, tempdir, reducers, + blockheight, oversampling, poweriters); + + return 0; + } + + public static void run(Configuration conf, Path input, Path output, int numDims, int clusters, + DistanceMeasure measure, double convergenceDelta, int maxIterations, Path tempDir) + throws IOException, InterruptedException, ClassNotFoundException { + run(conf, input, output, numDims, clusters, measure, convergenceDelta, maxIterations, tempDir, REDUCERS, + BLOCKHEIGHT, OVERSAMPLING, POWERITERS); + } + + /** + * Run the Spectral KMeans clustering on the supplied arguments + * + * @param conf + * the Configuration to be used + * @param input + * the Path to the input tuples directory + * @param output + * the Path to the output directory + * @param numDims + * the int number of dimensions of the affinity matrix + * @param clusters + * the int number of eigenvectors and thus clusters to produce + * @param measure + * the DistanceMeasure for the k-Means calculations + * @param convergenceDelta + * the double convergence delta for the k-Means calculations + * @param maxIterations + * the int maximum number of iterations for the k-Means calculations + * @param tempDir + * Temporary directory for intermediate calculations + * @param numReducers + * Number of reducers + * @param blockHeight + * @param oversampling + * @param poweriters + */ + public static void run(Configuration conf, Path input, Path output, int numDims, int clusters, + DistanceMeasure measure, double convergenceDelta, int maxIterations, Path tempDir, + int numReducers, int blockHeight, int oversampling, int poweriters) + throws IOException, InterruptedException, ClassNotFoundException { + + HadoopUtil.delete(conf, tempDir); + Path outputCalc = new Path(tempDir, "calculations"); + Path outputTmp = new Path(tempDir, "temporary"); + + // Take in the raw CSV text file and split it ourselves, + // creating our own SequenceFiles for the 
matrices to read later + // (similar to the style of syntheticcontrol.canopy.InputMapper) + Path affSeqFiles = new Path(outputCalc, "seqfile"); + AffinityMatrixInputJob.runJob(input, affSeqFiles, numDims, numDims); + + // Construct the affinity matrix using the newly-created sequence files + DistributedRowMatrix A = new DistributedRowMatrix(affSeqFiles, new Path(outputTmp, "afftmp"), numDims, numDims); + + Configuration depConf = new Configuration(conf); + A.setConf(depConf); + + // Construct the diagonal matrix D (represented as a vector) + Vector D = MatrixDiagonalizeJob.runJob(affSeqFiles, numDims); + + // Calculate the normalized Laplacian of the form: L = D^(-0.5)AD^(-0.5) + DistributedRowMatrix L = VectorMatrixMultiplicationJob.runJob(affSeqFiles, D, new Path(outputCalc, "laplacian"), + new Path(outputCalc, outputCalc)); + L.setConf(depConf); + + Path data; + + // SSVD requires an array of Paths to function. So we pass in an array of length one + Path[] LPath = new Path[1]; + LPath[0] = L.getRowPath(); + + Path SSVDout = new Path(outputCalc, "SSVD"); + + SSVDSolver solveIt = new SSVDSolver(depConf, LPath, SSVDout, blockHeight, clusters, oversampling, numReducers); + + solveIt.setComputeV(false); + solveIt.setComputeU(true); + solveIt.setOverwrite(true); + solveIt.setQ(poweriters); + // solveIt.setBroadcast(false); + solveIt.run(); + data = new Path(solveIt.getUPath()); + + // Normalize the rows of Wt to unit length + // normalize is important because it reduces the occurrence of two unique clusters combining into one + Path unitVectors = new Path(outputCalc, "unitvectors"); + + UnitVectorizerJob.runJob(data, unitVectors); + + DistributedRowMatrix Wt = new DistributedRowMatrix(unitVectors, new Path(unitVectors, "tmp"), clusters, numDims); + Wt.setConf(depConf); + data = Wt.getRowPath(); + + // Generate initial clusters using EigenSeedGenerator which picks rows as centroids if that row contains max + // eigen value in that column + Path initialclusters = EigenSeedGenerator.buildFromEigens(conf, data, + new Path(output, Cluster.INITIAL_CLUSTERS_DIR), clusters, measure); + + // Run the KMeansDriver + Path answer = new Path(output, "kmeans_out"); + KMeansDriver.run(conf, data, initialclusters, answer, convergenceDelta, maxIterations, true, 0.0, false); + + // Restore name to id mapping and read through the cluster assignments + Path mappingPath = new Path(new Path(conf.get("hadoop.tmp.dir")), "generic_input_mapping"); + List<String> mapping = new ArrayList<>(); + FileSystem fs = FileSystem.get(mappingPath.toUri(), conf); + if (fs.exists(mappingPath)) { + SequenceFile.Reader reader = new SequenceFile.Reader(fs, mappingPath, conf); + Text mappingValue = new Text(); + IntWritable mappingIndex = new IntWritable(); + while (reader.next(mappingIndex, mappingValue)) { + String s = mappingValue.toString(); + mapping.add(s); + } + HadoopUtil.delete(conf, mappingPath); + } else { + log.warn("generic input mapping file not found!"); + } + + Path clusteredPointsPath = new Path(answer, "clusteredPoints"); + Path inputPath = new Path(clusteredPointsPath, "part-m-00000"); + int id = 0; + for (Pair<IntWritable, WeightedVectorWritable> record : + new SequenceFileIterable<IntWritable, WeightedVectorWritable>(inputPath, conf)) { + if (!mapping.isEmpty()) { + log.info("{}: {}", mapping.get(id++), record.getFirst().get()); + } else { + log.info("{}: {}", id++, record.getFirst().get()); + } + } + } +} 
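To make the end-to-end flow of SpectralKMeansDriver concrete, the following is a minimal sketch of a programmatic call to the nine-argument run(...) overload above, which falls back to the SSVD defaults (REDUCERS, BLOCKHEIGHT, OVERSAMPLING, POWERITERS). The paths, dimensions and cluster count are hypothetical, the input is assumed to hold the raw "i,j,value" affinity triples that AffinityMatrixInputJob consumes, and EuclideanDistanceMeasure is just one DistanceMeasure that could be plugged in.

import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.mahout.clustering.spectral.kmeans.SpectralKMeansDriver;
import org.apache.mahout.common.distance.DistanceMeasure;
import org.apache.mahout.common.distance.EuclideanDistanceMeasure;

public class SpectralKMeansSketch {
  public static void main(String[] args) throws Exception {
    Configuration conf = new Configuration();
    Path input = new Path("/tmp/affinity.csv");   // hypothetical "i,j,value" triples
    Path output = new Path("/tmp/spectral-out");
    Path tempDir = new Path("/tmp/spectral-tmp");

    int numDims = 1000;                           // affinity matrix is numDims x numDims
    int clusters = 10;                            // top eigenvectors and clusters to produce
    DistanceMeasure measure = new EuclideanDistanceMeasure();
    double convergenceDelta = 0.5;
    int maxIterations = 20;

    SpectralKMeansDriver.run(conf, input, output, numDims, clusters, measure,
        convergenceDelta, maxIterations, tempDir);
  }
}

The command-line path through run(String[]) is roughly equivalent: it uses the input/output options from AbstractJob plus the --dimensions, --clusters and optional SSVD tuning options registered above.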
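The seeding rule in EigenSeedGenerator is easier to see on a small in-memory matrix: for every column (eigenvector index) it keeps the row whose entry has the largest absolute value, and that row becomes the initial Kluster center for the corresponding cluster. Below is a rough, non-MapReduce sketch of the same criterion; the class name EigenSeedSketch and the example matrix are illustrative only, with rows playing the role of the points in U and columns the k eigenvectors.

import java.util.HashMap;
import java.util.Map;

import org.apache.mahout.math.DenseMatrix;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.Vector;

public class EigenSeedSketch {

  /** Maps column index -> row index of the entry with maximum |value| in that column. */
  static Map<Integer, Integer> pickSeedRows(Matrix u) {
    Map<Integer, Double> maxAbs = new HashMap<>();
    Map<Integer, Integer> seedRow = new HashMap<>();
    for (int row = 0; row < u.numRows(); row++) {
      for (Vector.Element e : u.viewRow(row).nonZeroes()) {
        double v = Math.abs(e.get());
        if (!maxAbs.containsKey(e.index()) || v > maxAbs.get(e.index())) {
          maxAbs.put(e.index(), v);
          seedRow.put(e.index(), row);
        }
      }
    }
    return seedRow;
  }

  public static void main(String[] args) {
    Matrix u = new DenseMatrix(new double[][] {
        {0.9, 0.1},
        {0.2, -0.8},
        {0.4, 0.3}});
    System.out.println(pickSeedRows(u)); // {0=0, 1=1}: row 0 seeds cluster 0, row 1 seeds cluster 1
  }
}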
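One usage note on the VertexWritable class added earlier in this commit: its comment suggests the type field "can probably be null", but write() serializes it with DataOutput.writeUTF, which throws a NullPointerException for null strings, so callers should supply a non-null type (an empty string is enough). A quick round trip under that assumption, with in-memory streams standing in for Hadoop's serialization plumbing:

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;

import org.apache.mahout.clustering.spectral.VertexWritable;

public class VertexWritableRoundTrip {
  public static void main(String[] args) throws Exception {
    VertexWritable written = new VertexWritable(2, 5, 0.75, ""); // non-null type keeps writeUTF happy
    ByteArrayOutputStream bytes = new ByteArrayOutputStream();
    written.write(new DataOutputStream(bytes));

    VertexWritable restored = new VertexWritable();
    restored.readFields(new DataInputStream(new ByteArrayInputStream(bytes.toByteArray())));
    System.out.println(restored.getRow() + "," + restored.getCol() + " -> " + restored.getValue()); // 2,5 -> 0.75
  }
}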
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/clustering/streaming/cluster/BallKMeans.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/clustering/streaming/cluster/BallKMeans.java b/community/mahout-mr/src/main/java/org/apache/mahout/clustering/streaming/cluster/BallKMeans.java new file mode 100644 index 0000000..25806fe --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/clustering/streaming/cluster/BallKMeans.java @@ -0,0 +1,456 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.clustering.streaming.cluster; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.Iterator; +import java.util.List; +import java.util.Random; + +import com.google.common.base.Function; +import com.google.common.base.Preconditions; +import com.google.common.collect.Iterables; +import com.google.common.collect.Iterators; +import org.apache.mahout.clustering.ClusteringUtils; +import org.apache.mahout.common.Pair; +import org.apache.mahout.common.RandomUtils; +import org.apache.mahout.common.distance.DistanceMeasure; +import org.apache.mahout.math.Centroid; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.WeightedVector; +import org.apache.mahout.math.neighborhood.UpdatableSearcher; +import org.apache.mahout.math.random.Multinomial; +import org.apache.mahout.math.random.WeightedThing; + +/** + * Implements a ball k-means algorithm for weighted vectors with probabilistic seeding similar to k-means++. + * The idea is that k-means++ gives good starting clusters and ball k-means can tune up the final result very nicely + * in only a few passes (or even in a single iteration for well-clusterable data). + * + * A good reference for this class of algorithms is "The Effectiveness of Lloyd-Type Methods for the k-Means Problem" + * by Rafail Ostrovsky, Yuval Rabani, Leonard J. Schulman and Chaitanya Swamy. The code here uses the seeding strategy + * as described in section 4.1.1 of that paper and the ball k-means step as described in section 4.2. We support + * multiple iterations in contrast to the algorithm described in the paper. + */ +public class BallKMeans implements Iterable<Centroid> { + /** + * The searcher containing the centroids. + */ + private final UpdatableSearcher centroids; + + /** + * The number of clusters to cluster the data into. + */ + private final int numClusters; + + /** + * The maximum number of iterations of the algorithm to run waiting for the cluster assignments + * to stabilize. If there are no changes in cluster assignment earlier, we can finish early. 
+ */ + private final int maxNumIterations; + + /** + * When deciding which points to include in the new centroid calculation, + * it's preferable to exclude outliers since it increases the rate of convergence. + * So, we calculate the distance from each cluster to its closest neighboring cluster. When + * evaluating the points assigned to a cluster, we compare the distance between the centroid to + * the point with the distance between the centroid and its closest centroid neighbor + * multiplied by this trimFraction. If the distance between the centroid and the point is + * greater, we consider it an outlier and we don't use it. + */ + private final double trimFraction; + + /** + * Selecting the initial centroids is the most important part of the ball k-means clustering. Poor choices, like two + * centroids in the same actual cluster result in a low-quality final result. + * k-means++ initialization yields good quality clusters, especially when using BallKMeans after StreamingKMeans as + * the points have weights. + * Simple, random selection of the points based on their weights is faster but sometimes fails to produce the + * desired number of clusters. + * This field is true if the initialization should be done with k-means++. + */ + private final boolean kMeansPlusPlusInit; + + /** + * When using trimFraction, the weight of each centroid will not be the sum of the weights of + * the vectors assigned to that cluster because outliers are not used to compute the updated + * centroid. + * So, the total weight is probably wrong. This can be fixed by doing another pass over the + * data points and adjusting the weights of each centroid. This doesn't update the coordinates + * of the centroids, but is useful if the weights matter. + */ + private final boolean correctWeights; + + /** + * When running multiple ball k-means passes to get the one with the smallest total cost, can compute the + * overall cost, using all the points for clustering, or reserve a fraction of them, testProbability in a test set. + * The cost is the sum of the distances between each point and its corresponding centroid. + * We then use this set of points to compute the total cost on. We're therefore trying to select the clustering + * that best describes the underlying distribution of the clusters. + * This field is the probability of assigning a given point to the test set. If this is 0, the cost will be computed + * on the entire set of points. + */ + private final double testProbability; + + /** + * Whether or not testProbability > 0, i.e., there exists a non-empty 'test' set. + */ + private final boolean splitTrainTest; + + /** + * How many k-means runs to have. If there's more than one run, we compute the cost of each clustering as described + * above and select the clustering that minimizes the cost. + * Multiple runs are a lot more useful when using the random initialization. With kmeans++, 1-2 runs are enough and + * more runs are not likely to help quality much. + */ + private final int numRuns; + + /** + * Random object to sample values from. + */ + private final Random random; + + public BallKMeans(UpdatableSearcher searcher, int numClusters, int maxNumIterations) { + // By default, the trimFraction is 0.9, k-means++ is used, the weights will be corrected at the end, + // there will be 0 points in the test set and 1 run. 
+ this(searcher, numClusters, maxNumIterations, 0.9, true, true, 0.0, 1); + } + + public BallKMeans(UpdatableSearcher searcher, int numClusters, int maxNumIterations, + boolean kMeansPlusPlusInit, int numRuns) { + // By default, the trimFraction is 0.9, k-means++ is used, the weights will be corrected at the end, + // there will be 10% points of in the test set. + this(searcher, numClusters, maxNumIterations, 0.9, kMeansPlusPlusInit, true, 0.1, numRuns); + } + + public BallKMeans(UpdatableSearcher searcher, int numClusters, int maxNumIterations, + double trimFraction, boolean kMeansPlusPlusInit, boolean correctWeights, + double testProbability, int numRuns) { + Preconditions.checkArgument(searcher.size() == 0, "Searcher must be empty initially to populate with centroids"); + Preconditions.checkArgument(numClusters > 0, "The requested number of clusters must be positive"); + Preconditions.checkArgument(maxNumIterations > 0, "The maximum number of iterations must be positive"); + Preconditions.checkArgument(trimFraction > 0, "The trim fraction must be positive"); + Preconditions.checkArgument(testProbability >= 0 && testProbability < 1, "The testProbability must be in [0, 1)"); + Preconditions.checkArgument(numRuns > 0, "There has to be at least one run"); + + this.centroids = searcher; + this.numClusters = numClusters; + this.maxNumIterations = maxNumIterations; + + this.trimFraction = trimFraction; + this.kMeansPlusPlusInit = kMeansPlusPlusInit; + this.correctWeights = correctWeights; + + this.testProbability = testProbability; + this.splitTrainTest = testProbability > 0; + this.numRuns = numRuns; + + this.random = RandomUtils.getRandom(); + } + + public Pair<List<? extends WeightedVector>, List<? extends WeightedVector>> splitTrainTest( + List<? extends WeightedVector> datapoints) { + // If there will be no points assigned to the test set, return now. + if (testProbability == 0) { + return new Pair<List<? extends WeightedVector>, List<? extends WeightedVector>>(datapoints, + new ArrayList<WeightedVector>()); + } + + int numTest = (int) (testProbability * datapoints.size()); + Preconditions.checkArgument(numTest > 0 && numTest < datapoints.size(), + "Must have nonzero number of training and test vectors. Asked for %.1f %% of %d vectors for test", + testProbability * 100, datapoints.size()); + + Collections.shuffle(datapoints); + return new Pair<List<? extends WeightedVector>, List<? extends WeightedVector>>( + datapoints.subList(numTest, datapoints.size()), datapoints.subList(0, numTest)); + } + + /** + * Clusters the datapoints in the list doing either random seeding of the centroids or k-means++. + * + * @param datapoints the points to be clustered. + * @return an UpdatableSearcher with the resulting clusters. + */ + public UpdatableSearcher cluster(List<? extends WeightedVector> datapoints) { + Pair<List<? extends WeightedVector>, List<? extends WeightedVector>> trainTestSplit = splitTrainTest(datapoints); + List<Vector> bestCentroids = new ArrayList<>(); + double cost = Double.POSITIVE_INFINITY; + double bestCost = Double.POSITIVE_INFINITY; + for (int i = 0; i < numRuns; ++i) { + centroids.clear(); + if (kMeansPlusPlusInit) { + // Use k-means++ to set initial centroids. + initializeSeedsKMeansPlusPlus(trainTestSplit.getFirst()); + } else { + // Randomly select the initial centroids. + initializeSeedsRandomly(trainTestSplit.getFirst()); + } + // Do k-means iterations with trimmed mean computation (aka ball k-means). 
+ if (numRuns > 1) { + // If the clustering is successful (there are no zero-weight centroids). + iterativeAssignment(trainTestSplit.getFirst()); + // Compute the cost of the clustering and possibly save the centroids. + cost = ClusteringUtils.totalClusterCost( + splitTrainTest ? datapoints : trainTestSplit.getSecond(), centroids); + if (cost < bestCost) { + bestCost = cost; + bestCentroids.clear(); + Iterables.addAll(bestCentroids, centroids); + } + } else { + // If there is only going to be one run, the cost doesn't need to be computed, so we just return the clustering. + iterativeAssignment(datapoints); + return centroids; + } + } + if (bestCost == Double.POSITIVE_INFINITY) { + throw new RuntimeException("No valid clustering was found"); + } + if (cost != bestCost) { + centroids.clear(); + centroids.addAll(bestCentroids); + } + if (correctWeights) { + for (WeightedVector testDatapoint : trainTestSplit.getSecond()) { + WeightedVector closest = (WeightedVector) centroids.searchFirst(testDatapoint, false).getValue(); + closest.setWeight(closest.getWeight() + testDatapoint.getWeight()); + } + } + return centroids; + } + + /** + * Selects some of the original points randomly with probability proportional to their weights. This is much + * less sophisticated than the kmeans++ approach, however it is faster and coupled with + * + * The side effect of this method is to fill the centroids structure itself. + * + * @param datapoints The datapoints to select from. These datapoints should be WeightedVectors of some kind. + */ + private void initializeSeedsRandomly(List<? extends WeightedVector> datapoints) { + int numDatapoints = datapoints.size(); + double totalWeight = 0; + for (WeightedVector datapoint : datapoints) { + totalWeight += datapoint.getWeight(); + } + Multinomial<Integer> seedSelector = new Multinomial<>(); + for (int i = 0; i < numDatapoints; ++i) { + seedSelector.add(i, datapoints.get(i).getWeight() / totalWeight); + } + for (int i = 0; i < numClusters; ++i) { + int sample = seedSelector.sample(); + seedSelector.delete(sample); + Centroid centroid = new Centroid(datapoints.get(sample)); + centroid.setIndex(i); + centroids.add(centroid); + } + } + + /** + * Selects some of the original points according to the k-means++ algorithm. The basic idea is that + * points are selected with probability proportional to their distance from any selected point. In + * this version, points have weights which multiply their likelihood of being selected. This is the + * same as if there were as many copies of the same point as indicated by the weight. + * + * This is pretty expensive, but it vastly improves the quality and convergences of the k-means algorithm. + * The basic idea can be made much faster by only processing a random subset of the original points. + * In the context of streaming k-means, the total number of possible seeds will be about k log n so this + * selection will cost O(k^2 (log n)^2) which isn't much worse than the random sampling idea. At + * n = 10^9, the cost of this initialization will be about 10x worse than a reasonable random sampling + * implementation. + * + * The side effect of this method is to fill the centroids structure itself. + * + * @param datapoints The datapoints to select from. These datapoints should be WeightedVectors of some kind. + */ + private void initializeSeedsKMeansPlusPlus(List<? 
extends WeightedVector> datapoints) { + Preconditions.checkArgument(datapoints.size() > 1, "Must have at least two datapoints points to cluster " + + "sensibly"); + Preconditions.checkArgument(datapoints.size() >= numClusters, + String.format("Must have more datapoints [%d] than clusters [%d]", datapoints.size(), numClusters)); + // Compute the centroid of all of the datapoints. This is then used to compute the squared radius of the datapoints. + Centroid center = new Centroid(datapoints.iterator().next()); + for (WeightedVector row : Iterables.skip(datapoints, 1)) { + center.update(row); + } + + // Given the centroid, we can compute \Delta_1^2(X), the total squared distance for the datapoints + // this accelerates seed selection. + double deltaX = 0; + DistanceMeasure distanceMeasure = centroids.getDistanceMeasure(); + for (WeightedVector row : datapoints) { + deltaX += distanceMeasure.distance(row, center); + } + + // Find the first seed c_1 (and conceptually the second, c_2) as might be done in the 2-means clustering so that + // the probability of selecting c_1 and c_2 is proportional to || c_1 - c_2 ||^2. This is done + // by first selecting c_1 with probability: + // + // p(c_1) = sum_{c_1} || c_1 - c_2 ||^2 \over sum_{c_1, c_2} || c_1 - c_2 ||^2 + // + // This can be simplified to: + // + // p(c_1) = \Delta_1^2(X) + n || c_1 - c ||^2 / (2 n \Delta_1^2(X)) + // + // where c = \sum x / n and \Delta_1^2(X) = sum || x - c ||^2 + // + // All subsequent seeds c_i (including c_2) can then be selected from the remaining points with probability + // proportional to Pr(c_i == x_j) = min_{m < i} || c_m - x_j ||^2. + + // Multinomial distribution of vector indices for the selection seeds. These correspond to + // the indices of the vectors in the original datapoints list. + Multinomial<Integer> seedSelector = new Multinomial<>(); + for (int i = 0; i < datapoints.size(); ++i) { + double selectionProbability = + deltaX + datapoints.size() * distanceMeasure.distance(datapoints.get(i), center); + seedSelector.add(i, selectionProbability); + } + + int selected = random.nextInt(datapoints.size()); + Centroid c_1 = new Centroid(datapoints.get(selected).clone()); + c_1.setIndex(0); + // Construct a set of weighted things which can be used for random selection. Initial weights are + // set to the squared distance from c_1 + for (int i = 0; i < datapoints.size(); ++i) { + WeightedVector row = datapoints.get(i); + double w = distanceMeasure.distance(c_1, row) * 2 * Math.log(1 + row.getWeight()); + seedSelector.set(i, w); + } + + // From here, seeds are selected with probability proportional to: + // + // r_i = min_{c_j} || x_i - c_j ||^2 + // + // when we only have c_1, we have already set these distances and as we select each new + // seed, we update the minimum distances. + centroids.add(c_1); + int clusterIndex = 1; + while (centroids.size() < numClusters) { + // Select according to weights. + int seedIndex = seedSelector.sample(); + Centroid nextSeed = new Centroid(datapoints.get(seedIndex)); + nextSeed.setIndex(clusterIndex++); + centroids.add(nextSeed); + // Don't select this one again. + seedSelector.delete(seedIndex); + // Re-weight everything according to the minimum distance to a seed. 
+ for (int currSeedIndex : seedSelector) { + WeightedVector curr = datapoints.get(currSeedIndex); + double newWeight = nextSeed.getWeight() * distanceMeasure.distance(nextSeed, curr); + if (newWeight < seedSelector.getWeight(currSeedIndex)) { + seedSelector.set(currSeedIndex, newWeight); + } + } + } + } + + /** + * Examines the datapoints and updates cluster centers to be the centroid of the nearest datapoints points. To + * compute a new center for cluster c_i, we average all points that are closer than d_i * trimFraction + * where d_i is + * + * d_i = min_j \sqrt ||c_j - c_i||^2 + * + * By ignoring distant points, the centroids converge more quickly to a good approximation of the + * optimal k-means solution (given good starting points). + * + * @param datapoints the points to cluster. + */ + private void iterativeAssignment(List<? extends WeightedVector> datapoints) { + DistanceMeasure distanceMeasure = centroids.getDistanceMeasure(); + // closestClusterDistances.get(i) is the distance from the i'th cluster to its closest + // neighboring cluster. + List<Double> closestClusterDistances = new ArrayList<>(numClusters); + // clusterAssignments[i] == j means that the i'th point is assigned to the j'th cluster. When + // these don't change, we are done. + // Each point is assigned to the invalid "-1" cluster initially. + List<Integer> clusterAssignments = new ArrayList<>(Collections.nCopies(datapoints.size(), -1)); + + boolean changed = true; + for (int i = 0; changed && i < maxNumIterations; i++) { + changed = false; + // We compute what the distance between each cluster and its closest neighbor is to set a + // proportional distance threshold for points that should be involved in calculating the + // centroid. + closestClusterDistances.clear(); + for (Vector center : centroids) { + // If a centroid has no points assigned to it, the clustering failed. + Vector closestOtherCluster = centroids.searchFirst(center, true).getValue(); + closestClusterDistances.add(distanceMeasure.distance(center, closestOtherCluster)); + } + + // Copies the current cluster centroids to newClusters and sets their weights to 0. This is + // so we calculate the new centroids as we go through the datapoints. + List<Centroid> newCentroids = new ArrayList<>(); + for (Vector centroid : centroids) { + // need a deep copy because we will mutate these values + Centroid newCentroid = (Centroid)centroid.clone(); + newCentroid.setWeight(0); + newCentroids.add(newCentroid); + } + + // Pass over the datapoints computing new centroids. + for (int j = 0; j < datapoints.size(); ++j) { + WeightedVector datapoint = datapoints.get(j); + // Get the closest cluster this point belongs to. + WeightedThing<Vector> closestPair = centroids.searchFirst(datapoint, false); + int closestIndex = ((WeightedVector) closestPair.getValue()).getIndex(); + double closestDistance = closestPair.getWeight(); + // Update its cluster assignment if necessary. + if (closestIndex != clusterAssignments.get(j)) { + changed = true; + clusterAssignments.set(j, closestIndex); + } + // Only update if the datapoints point is near enough. What this means is that the weight + // of outliers is NOT taken into account and the final weights of the centroids will + // reflect this (it will be less or equal to the initial sum of the weights). + if (closestDistance < trimFraction * closestClusterDistances.get(closestIndex)) { + newCentroids.get(closestIndex).update(datapoint); + } + } + // Add the new centers back into searcher. 
+ centroids.clear(); + centroids.addAll(newCentroids); + } + + if (correctWeights) { + for (Vector v : centroids) { + ((Centroid)v).setWeight(0); + } + for (WeightedVector datapoint : datapoints) { + Centroid closestCentroid = (Centroid) centroids.searchFirst(datapoint, false).getValue(); + closestCentroid.setWeight(closestCentroid.getWeight() + datapoint.getWeight()); + } + } + } + + @Override + public Iterator<Centroid> iterator() { + return Iterators.transform(centroids.iterator(), new Function<Vector, Centroid>() { + @Override + public Centroid apply(Vector input) { + Preconditions.checkArgument(input instanceof Centroid, "Non-centroid in centroids " + + "searcher"); + //noinspection ConstantConditions + return (Centroid)input; + } + }); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/clustering/streaming/cluster/StreamingKMeans.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/clustering/streaming/cluster/StreamingKMeans.java b/community/mahout-mr/src/main/java/org/apache/mahout/clustering/streaming/cluster/StreamingKMeans.java new file mode 100644 index 0000000..604bc9d --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/clustering/streaming/cluster/StreamingKMeans.java @@ -0,0 +1,368 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.clustering.streaming.cluster; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.Iterator; +import java.util.List; +import java.util.Random; + +import com.google.common.base.Function; +import com.google.common.collect.Iterables; +import com.google.common.collect.Iterators; +import org.apache.mahout.common.RandomUtils; +import org.apache.mahout.common.distance.DistanceMeasure; +import org.apache.mahout.math.Centroid; +import org.apache.mahout.math.Matrix; +import org.apache.mahout.math.MatrixSlice; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.jet.math.Constants; +import org.apache.mahout.math.neighborhood.UpdatableSearcher; +import org.apache.mahout.math.random.WeightedThing; + +/** + * Implements a streaming k-means algorithm for weighted vectors. + * The goal clustering points one at a time, especially useful for MapReduce mappers that get inputs one at a time. + * + * A rough description of the algorithm: + * Suppose there are l clusters at one point and a new point p is added. + * The new point can either be added to one of the existing l clusters or become a new cluster. 
To decide: + * - let c be the closest cluster to point p; + * - let d be the distance between c and p; + * - if d > distanceCutoff, create a new cluster from p (p is too far away from the clusters to be part of them; + * distanceCutoff represents the largest distance from a point its assigned cluster's centroid); + * - else (d <= distanceCutoff), create a new cluster with probability d / distanceCutoff (the probability of creating + * a new cluster increases as d increases). + * There will be either l points or l + 1 points after processing a new point. + * + * As the number of clusters increases, it will go over the numClusters limit (numClusters represents a recommendation + * for the number of clusters that there should be at the end). To decrease the number of clusters the existing clusters + * are treated as data points and are re-clustered (collapsed). This tends to make the number of clusters go down. + * If the number of clusters is still too high, distanceCutoff is increased. + * + * For more details, see: + * - "Streaming k-means approximation" by N. Ailon, R. Jaiswal, C. Monteleoni + * http://books.nips.cc/papers/files/nips22/NIPS2009_1085.pdf + * - "Fast and Accurate k-means for Large Datasets" by M. Shindler, A. Wong, A. Meyerson, + * http://books.nips.cc/papers/files/nips24/NIPS2011_1271.pdf + */ +public class StreamingKMeans implements Iterable<Centroid> { + /** + * The searcher containing the centroids that resulted from the clustering of points until now. When adding a new + * point we either assign it to one of the existing clusters in this searcher or create a new centroid for it. + */ + private final UpdatableSearcher centroids; + + /** + * The estimated number of clusters to cluster the data in. If the actual number of clusters increases beyond this + * limit, the clusters will be "collapsed" (re-clustered, by treating them as data points). This doesn't happen + * recursively and a collapse might not necessarily make the number of actual clusters drop to less than this limit. + * + * If the goal is clustering a large data set into k clusters, numClusters SHOULD NOT BE SET to k. StreamingKMeans is + * useful to reduce the size of the data set by the mappers so that it can fit into memory in one reducer that runs + * BallKMeans. + * + * It is NOT MEANT to cluster the data into k clusters in one pass because it can't guarantee that there will in fact + * be k clusters in total. This is because of the dynamic nature of numClusters over the course of the runtime. + * To get an exact number of clusters, another clustering algorithm needs to be applied to the results. + */ + private int numClusters; + + /** + * The number of data points seen so far. This is important for re-estimating numClusters when deciding to collapse + * the existing clusters. + */ + private int numProcessedDatapoints = 0; + + /** + * This is the current value of the distance cutoff. Points which are much closer than this to a centroid will stick + * to it almost certainly. Points further than this to any centroid will form a new cluster. + * + * This increases (is multiplied by beta) when a cluster collapse did not make the number of clusters drop to below + * numClusters (it effectively increases the tolerance for cluster compactness discouraging the creation of new + * clusters). Since a collapse only happens when centroids.size() > clusterOvershoot * numClusters, the cutoff + * increases when the collapse didn't at least remove the slack in the number of clusters. 
+ */ + private double distanceCutoff; + + /** + * Parameter that controls the growth of the distanceCutoff. After n increases of the + * distanceCutoff starting at d_0, the final value is d_0 * beta^n (distance cutoffs increase following a geometric + * progression with ratio beta). + */ + private final double beta; + + /** + * Multiplying clusterLogFactor with numProcessedDatapoints gets an estimate of the suggested + * number of clusters. This mirrors the recommended number of clusters for n points where there should be k actual + * clusters, k * log n. In the case of our estimate we use clusterLogFactor * log(numProcessedDataPoints). + * + * It is important to note that numClusters is NOT k. It is an estimate of k * log n. + */ + private final double clusterLogFactor; + + /** + * Centroids are collapsed when the number of clusters becomes greater than clusterOvershoot * numClusters. This + * effectively means having a slack in numClusters so that the actual number of centroids, centroids.size() tracks + * numClusters approximately. The idea is that the actual number of clusters should be at least numClusters but not + * much more (so that we don't end up having 1 cluster / point). + */ + private final double clusterOvershoot; + + /** + * Random object to sample values from. + */ + private final Random random = RandomUtils.getRandom(); + + /** + * Calls StreamingKMeans(searcher, numClusters, 1.3, 10, 2). + * @see StreamingKMeans#StreamingKMeans(org.apache.mahout.math.neighborhood.UpdatableSearcher, int, + * double, double, double, double) + */ + public StreamingKMeans(UpdatableSearcher searcher, int numClusters) { + this(searcher, numClusters, 1.0 / numClusters, 1.3, 20, 2); + } + + /** + * Calls StreamingKMeans(searcher, numClusters, distanceCutoff, 1.3, 10, 2). + * @see StreamingKMeans#StreamingKMeans(org.apache.mahout.math.neighborhood.UpdatableSearcher, int, + * double, double, double, double) + */ + public StreamingKMeans(UpdatableSearcher searcher, int numClusters, double distanceCutoff) { + this(searcher, numClusters, distanceCutoff, 1.3, 20, 2); + } + + /** + * Creates a new StreamingKMeans class given a searcher and the number of clusters to generate. + * + * @param searcher A Searcher that is used for performing nearest neighbor search. It MUST BE + * EMPTY initially because it will be used to keep track of the cluster + * centroids. + * @param numClusters An estimated number of clusters to generate for the data points. + * This can adjusted, but the actual number will depend on the data. The + * @param distanceCutoff The initial distance cutoff representing the value of the + * distance between a point and its closest centroid after which + * the new point will definitely be assigned to a new cluster. + * @param beta Ratio of geometric progression to use when increasing distanceCutoff. After n increases, distanceCutoff + * becomes distanceCutoff * beta^n. A smaller value increases the distanceCutoff less aggressively. + * @param clusterLogFactor Value multiplied with the number of points counted so far estimating the number of clusters + * to aim for. If the final number of clusters is known and this clustering is only for a + * sketch of the data, this can be the final number of clusters, k. + * @param clusterOvershoot Multiplicative slack factor for slowing down the collapse of the clusters. 
+ */ + public StreamingKMeans(UpdatableSearcher searcher, int numClusters, + double distanceCutoff, double beta, double clusterLogFactor, double clusterOvershoot) { + this.centroids = searcher; + this.numClusters = numClusters; + this.distanceCutoff = distanceCutoff; + this.beta = beta; + this.clusterLogFactor = clusterLogFactor; + this.clusterOvershoot = clusterOvershoot; + } + + /** + * @return an Iterator to the Centroids contained in this clusterer. + */ + @Override + public Iterator<Centroid> iterator() { + return Iterators.transform(centroids.iterator(), new Function<Vector, Centroid>() { + @Override + public Centroid apply(Vector input) { + return (Centroid)input; + } + }); + } + + /** + * Cluster the rows of a matrix, treating them as Centroids with weight 1. + * @param data matrix whose rows are to be clustered. + * @return the UpdatableSearcher containing the resulting centroids. + */ + public UpdatableSearcher cluster(Matrix data) { + return cluster(Iterables.transform(data, new Function<MatrixSlice, Centroid>() { + @Override + public Centroid apply(MatrixSlice input) { + // The key in a Centroid is actually the MatrixSlice's index. + return Centroid.create(input.index(), input.vector()); + } + })); + } + + /** + * Cluster the data points in an Iterable<Centroid>. + * @param datapoints Iterable whose elements are to be clustered. + * @return the UpdatableSearcher containing the resulting centroids. + */ + public UpdatableSearcher cluster(Iterable<Centroid> datapoints) { + return clusterInternal(datapoints, false); + } + + /** + * Cluster one data point. + * @param datapoint to be clustered. + * @return the UpdatableSearcher containing the resulting centroids. + */ + public UpdatableSearcher cluster(final Centroid datapoint) { + return cluster(new Iterable<Centroid>() { + @Override + public Iterator<Centroid> iterator() { + return new Iterator<Centroid>() { + private boolean accessed = false; + + @Override + public boolean hasNext() { + return !accessed; + } + + @Override + public Centroid next() { + accessed = true; + return datapoint; + } + + @Override + public void remove() { + throw new UnsupportedOperationException(); + } + }; + } + }); + } + + /** + * @return the number of clusters computed from the points until now. + */ + public int getNumClusters() { + return centroids.size(); + } + + /** + * Internal clustering method that gets called from the other wrappers. + * @param datapoints Iterable of data points to be clustered. + * @param collapseClusters whether this is an "inner" clustering and the datapoints are the previously computed + * centroids. Some logic is different to ensure counters are consistent but it behaves + * nearly the same. + * @return the UpdatableSearcher containing the resulting centroids. + */ + private UpdatableSearcher clusterInternal(Iterable<Centroid> datapoints, boolean collapseClusters) { + Iterator<Centroid> datapointsIterator = datapoints.iterator(); + if (!datapointsIterator.hasNext()) { + return centroids; + } + + int oldNumProcessedDataPoints = numProcessedDatapoints; + // We clear the centroids we have in case of cluster collapse, the old clusters are the + // datapoints but we need to re-cluster them. + if (collapseClusters) { + centroids.clear(); + numProcessedDatapoints = 0; + } + + if (centroids.size() == 0) { + // Assign the first datapoint to the first cluster. + // Adding a vector to a searcher would normally just reference the copy, + // but we could potentially mutate it and so we need to make a clone. 
+ centroids.add(datapointsIterator.next().clone()); + ++numProcessedDatapoints; + } + + // To cluster, we scan the data and either add each point to the nearest group or create a new group. + // when we get too many groups, we need to increase the threshold and rescan our current groups + while (datapointsIterator.hasNext()) { + Centroid row = datapointsIterator.next(); + // Get the closest vector and its weight as a WeightedThing<Vector>. + // The weight of the WeightedThing is the distance to the query and the value is a + // reference to one of the vectors we added to the searcher previously. + WeightedThing<Vector> closestPair = centroids.searchFirst(row, false); + + // We get a uniformly distributed random number between 0 and 1 and compare it with the + // distance to the closest cluster divided by the distanceCutoff. + // This is so that if the closest cluster is further than distanceCutoff, + // closestPair.getWeight() / distanceCutoff > 1 which will trigger the creation of a new + // cluster anyway. + // However, if the ratio is less than 1, we want to create a new cluster with probability + // proportional to the distance to the closest cluster. + double sample = random.nextDouble(); + if (sample < row.getWeight() * closestPair.getWeight() / distanceCutoff) { + // Add new centroid, note that the vector is copied because we may mutate it later. + centroids.add(row.clone()); + } else { + // Merge the new point with the existing centroid. This will update the centroid's actual + // position. + // We know that all the points we inserted in the centroids searcher are (or extend) + // WeightedVector, so the cast will always succeed. + Centroid centroid = (Centroid) closestPair.getValue(); + + // We will update the centroid by removing it from the searcher and reinserting it to + // ensure consistency. + if (!centroids.remove(centroid, Constants.EPSILON)) { + throw new RuntimeException("Unable to remove centroid"); + } + centroid.update(row); + centroids.add(centroid); + + } + ++numProcessedDatapoints; + + if (!collapseClusters && centroids.size() > clusterOvershoot * numClusters) { + numClusters = (int) Math.max(numClusters, clusterLogFactor * Math.log(numProcessedDatapoints)); + + List<Centroid> shuffled = new ArrayList<>(); + for (Vector vector : centroids) { + shuffled.add((Centroid) vector); + } + Collections.shuffle(shuffled); + // Re-cluster using the shuffled centroids as data points. The centroids member variable + // is modified directly. + clusterInternal(shuffled, true); + + if (centroids.size() > numClusters) { + distanceCutoff *= beta; + } + } + } + + if (collapseClusters) { + numProcessedDatapoints = oldNumProcessedDataPoints; + } + return centroids; + } + + public void reindexCentroids() { + int numCentroids = 0; + for (Centroid centroid : this) { + centroid.setIndex(numCentroids++); + } + } + + /** + * @return the distanceCutoff (an upper bound for the maximum distance within a cluster). 
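+ * Note that this value is not a constant: each collapse pass in clusterInternal that still ends with more than numClusters + * centroids multiplies it by beta, so the cutoff only grows as more points are processed.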
+ */ + public double getDistanceCutoff() { + return distanceCutoff; + } + + public void setDistanceCutoff(double distanceCutoff) { + this.distanceCutoff = distanceCutoff; + } + + public DistanceMeasure getDistanceMeasure() { + return centroids.getDistanceMeasure(); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/CentroidWritable.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/CentroidWritable.java b/community/mahout-mr/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/CentroidWritable.java new file mode 100644 index 0000000..a41940b --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/clustering/streaming/mapreduce/CentroidWritable.java @@ -0,0 +1,88 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.clustering.streaming.mapreduce; + +import org.apache.hadoop.io.Writable; +import org.apache.mahout.math.Centroid; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.VectorWritable; + +import java.io.DataInput; +import java.io.DataOutput; +import java.io.IOException; + +public class CentroidWritable implements Writable { + private Centroid centroid = null; + + public CentroidWritable() {} + + public CentroidWritable(Centroid centroid) { + this.centroid = centroid; + } + + public Centroid getCentroid() { + return centroid; + } + + @Override + public void write(DataOutput dataOutput) throws IOException { + dataOutput.writeInt(centroid.getIndex()); + dataOutput.writeDouble(centroid.getWeight()); + VectorWritable.writeVector(dataOutput, centroid.getVector()); + } + + @Override + public void readFields(DataInput dataInput) throws IOException { + if (centroid == null) { + centroid = read(dataInput); + return; + } + centroid.setIndex(dataInput.readInt()); + centroid.setWeight(dataInput.readDouble()); + centroid.assign(VectorWritable.readVector(dataInput)); + } + + public static Centroid read(DataInput dataInput) throws IOException { + int index = dataInput.readInt(); + double weight = dataInput.readDouble(); + Vector v = VectorWritable.readVector(dataInput); + return new Centroid(index, v, weight); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof CentroidWritable)) { + return false; + } + CentroidWritable writable = (CentroidWritable) o; + return centroid.equals(writable.centroid); + } + + @Override + public int hashCode() { + return centroid.hashCode(); + } + + @Override + public String toString() { + return centroid.toString(); + } +}
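The CentroidWritable above encodes a centroid as its index, then its weight, then its vector via VectorWritable. As a quick orientation for readers of this digest, here is a minimal round-trip sketch, not part of the commit: it uses only the methods shown in the diff plus DenseVector and the java.io stream wrappers, and CentroidWritableRoundTrip is a hypothetical driver class.

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;

import org.apache.mahout.clustering.streaming.mapreduce.CentroidWritable;
import org.apache.mahout.math.Centroid;
import org.apache.mahout.math.DenseVector;

public class CentroidWritableRoundTrip {
  public static void main(String[] args) throws IOException {
    // Build a centroid with a non-default weight so all three serialized fields are exercised.
    Centroid original = Centroid.create(7, new DenseVector(new double[] {1.0, 2.0, 3.0}));
    original.setWeight(4.5);

    // write() emits the index, the weight, then the vector through VectorWritable.writeVector.
    ByteArrayOutputStream bytes = new ByteArrayOutputStream();
    new CentroidWritable(original).write(new DataOutputStream(bytes));

    // readFields() on an empty writable delegates to the static read() helper.
    CentroidWritable copy = new CentroidWritable();
    copy.readFields(new DataInputStream(new ByteArrayInputStream(bytes.toByteArray())));

    Centroid restored = copy.getCentroid();
    System.out.println(restored.getIndex() + " " + restored.getWeight() + " " + restored.getVector());
  }
}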

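Similarly, the StreamingKMeans class added earlier in this commit can be driven entirely in memory. The sketch below is illustrative rather than definitive: it assumes the existing Mahout BruteSearch searcher and EuclideanDistanceMeasure (any UpdatableSearcher would do), assumes the usual package locations for these classes, and StreamingKMeansExample is a hypothetical driver.

import java.util.ArrayList;
import java.util.List;
import java.util.Random;

import org.apache.mahout.clustering.streaming.cluster.StreamingKMeans;
import org.apache.mahout.common.distance.EuclideanDistanceMeasure;
import org.apache.mahout.math.Centroid;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.neighborhood.BruteSearch;
import org.apache.mahout.math.neighborhood.UpdatableSearcher;

public class StreamingKMeansExample {
  public static void main(String[] args) {
    // The searcher must start out empty; the clusterer uses it to hold the centroids.
    UpdatableSearcher searcher = new BruteSearch(new EuclideanDistanceMeasure());

    // numClusters is the sketch size to aim for (an estimate of k * log n), not the final k.
    StreamingKMeans clusterer = new StreamingKMeans(searcher, 100);

    // Wrap each input point as a weight-1 Centroid keyed by its position, as cluster(Matrix) does.
    Random rand = new Random(42);
    List<Centroid> points = new ArrayList<>();
    for (int i = 0; i < 10000; i++) {
      DenseVector point = new DenseVector(new double[] {rand.nextGaussian(), rand.nextGaussian()});
      points.add(Centroid.create(i, point));
    }

    clusterer.cluster(points);
    clusterer.reindexCentroids();
    for (Centroid centroid : clusterer) {
      System.out.println(centroid.getIndex() + " -> weight " + centroid.getWeight() + ", " + centroid.getVector());
    }
  }
}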