http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationMapper.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationMapper.java b/community/mahout-mr/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationMapper.java new file mode 100644 index 0000000..fffa7f9 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationMapper.java @@ -0,0 +1,161 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.clustering.classify; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FileStatus; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.Text; +import org.apache.hadoop.io.Writable; +import org.apache.hadoop.io.WritableComparable; +import org.apache.hadoop.mapreduce.Mapper; +import org.apache.mahout.clustering.Cluster; +import org.apache.mahout.clustering.iterator.ClusterWritable; +import org.apache.mahout.clustering.iterator.ClusteringPolicy; +import org.apache.mahout.clustering.iterator.DistanceMeasureCluster; +import org.apache.mahout.common.distance.DistanceMeasure; +import org.apache.mahout.common.iterator.sequencefile.PathFilters; +import org.apache.mahout.common.iterator.sequencefile.PathType; +import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirValueIterator; +import org.apache.mahout.math.NamedVector; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.Vector.Element; +import org.apache.mahout.math.VectorWritable; + +/** + * Mapper for classifying vectors into clusters. 
+ */ +public class ClusterClassificationMapper extends + Mapper<WritableComparable<?>,VectorWritable,IntWritable,WeightedVectorWritable> { + + private double threshold; + private List<Cluster> clusterModels; + private ClusterClassifier clusterClassifier; + private IntWritable clusterId; + private boolean emitMostLikely; + + @Override + protected void setup(Context context) throws IOException, InterruptedException { + super.setup(context); + + Configuration conf = context.getConfiguration(); + String clustersIn = conf.get(ClusterClassificationConfigKeys.CLUSTERS_IN); + threshold = conf.getFloat(ClusterClassificationConfigKeys.OUTLIER_REMOVAL_THRESHOLD, 0.0f); + emitMostLikely = conf.getBoolean(ClusterClassificationConfigKeys.EMIT_MOST_LIKELY, false); + + clusterModels = new ArrayList<>(); + + if (clustersIn != null && !clustersIn.isEmpty()) { + Path clustersInPath = new Path(clustersIn); + clusterModels = populateClusterModels(clustersInPath, conf); + ClusteringPolicy policy = ClusterClassifier + .readPolicy(finalClustersPath(clustersInPath)); + clusterClassifier = new ClusterClassifier(clusterModels, policy); + } + clusterId = new IntWritable(); + } + + /** + * Mapper which classifies the vectors to respective clusters. + */ + @Override + protected void map(WritableComparable<?> key, VectorWritable vw, Context context) + throws IOException, InterruptedException { + if (!clusterModels.isEmpty()) { + // Converting to NamedVectors to preserve the vectorId else its not obvious as to which point + // belongs to which cluster - fix for MAHOUT-1410 + Class<? 
extends Vector> vectorClass = vw.get().getClass(); + Vector vector = vw.get(); + if (!vectorClass.equals(NamedVector.class)) { + if (key.getClass().equals(Text.class)) { + vector = new NamedVector(vector, key.toString()); + } else if (key.getClass().equals(IntWritable.class)) { + vector = new NamedVector(vector, Integer.toString(((IntWritable) key).get())); + } + } + Vector pdfPerCluster = clusterClassifier.classify(vector); + if (shouldClassify(pdfPerCluster)) { + if (emitMostLikely) { + int maxValueIndex = pdfPerCluster.maxValueIndex(); + write(new VectorWritable(vector), context, maxValueIndex, 1.0); + } else { + writeAllAboveThreshold(new VectorWritable(vector), context, pdfPerCluster); + } + } + } + } + + private void writeAllAboveThreshold(VectorWritable vw, Context context, + Vector pdfPerCluster) throws IOException, InterruptedException { + for (Element pdf : pdfPerCluster.nonZeroes()) { + if (pdf.get() >= threshold) { + int clusterIndex = pdf.index(); + write(vw, context, clusterIndex, pdf.get()); + } + } + } + + private void write(VectorWritable vw, Context context, int clusterIndex, double weight) + throws IOException, InterruptedException { + Cluster cluster = clusterModels.get(clusterIndex); + clusterId.set(cluster.getId()); + + DistanceMeasureCluster distanceMeasureCluster = (DistanceMeasureCluster) cluster; + DistanceMeasure distanceMeasure = distanceMeasureCluster.getMeasure(); + double distance = distanceMeasure.distance(cluster.getCenter(), vw.get()); + + Map<Text, Text> props = new HashMap<>(); + props.put(new Text("distance"), new Text(Double.toString(distance))); + context.write(clusterId, new WeightedPropertyVectorWritable(weight, vw.get(), props)); + } + + public static List<Cluster> populateClusterModels(Path clusterOutputPath, Configuration conf) throws IOException { + List<Cluster> clusters = new ArrayList<>(); + FileSystem fileSystem = clusterOutputPath.getFileSystem(conf); + FileStatus[] clusterFiles = 
fileSystem.listStatus(clusterOutputPath, PathFilters.finalPartFilter()); + Iterator<?> it = new SequenceFileDirValueIterator<>( + clusterFiles[0].getPath(), PathType.LIST, PathFilters.partFilter(), + null, false, conf); + while (it.hasNext()) { + ClusterWritable next = (ClusterWritable) it.next(); + Cluster cluster = next.getValue(); + cluster.configure(conf); + clusters.add(cluster); + } + return clusters; + } + + private boolean shouldClassify(Vector pdfPerCluster) { + return pdfPerCluster.maxValue() >= threshold; + } + + private static Path finalClustersPath(Path clusterOutputPath) throws IOException { + FileSystem fileSystem = clusterOutputPath.getFileSystem(new Configuration()); + FileStatus[] clusterFiles = fileSystem.listStatus(clusterOutputPath, PathFilters.finalPartFilter()); + return clusterFiles[0].getPath(); + } +}
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/clustering/classify/ClusterClassifier.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/clustering/classify/ClusterClassifier.java b/community/mahout-mr/src/main/java/org/apache/mahout/clustering/classify/ClusterClassifier.java new file mode 100644 index 0000000..dcd4062 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/clustering/classify/ClusterClassifier.java @@ -0,0 +1,231 @@ +/* Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.mahout.clustering.classify; + +import java.io.DataInput; +import java.io.DataOutput; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Locale; + +import com.google.common.io.Closeables; +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.SequenceFile; +import org.apache.hadoop.io.Text; +import org.apache.hadoop.io.Writable; +import org.apache.mahout.classifier.AbstractVectorClassifier; +import org.apache.mahout.classifier.OnlineLearner; +import org.apache.mahout.clustering.Cluster; +import org.apache.mahout.clustering.iterator.ClusterWritable; +import org.apache.mahout.clustering.iterator.ClusteringPolicy; +import org.apache.mahout.clustering.iterator.ClusteringPolicyWritable; +import org.apache.mahout.common.ClassUtils; +import org.apache.mahout.common.iterator.sequencefile.PathFilters; +import org.apache.mahout.common.iterator.sequencefile.PathType; +import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirValueIterable; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.VectorWritable; + +/** + * This classifier works with any ClusteringPolicy and its associated Clusters. + * It is initialized with a policy and a list of compatible clusters and + * thereafter it can classify any new Vector into one or more of the clusters + * based upon the pdf() function which each cluster supports. + * <p/> + * In addition, it is an OnlineLearner and can be trained. Training amounts to + * asking the actual model to observe the vector and closing the classifier + * causes all the models to computeParameters. + * <p/> + * Because a ClusterClassifier implements Writable, it can be written-to and + * read-from a sequence file as a single entity. 
* For sequential and MapReduce
 * clustering in conjunction with a ClusterIterator; however, it utilizes an
 * exploded file format. In this format, the iterator writes the policy to a
 * single POLICY_FILE_NAME file in the clustersOut directory and the models are
 * written to one or more part-n files so that multiple reducers may be employed to
 * produce them.
 */
public class ClusterClassifier extends AbstractVectorClassifier implements OnlineLearner, Writable {

  private static final String POLICY_FILE_NAME = "_policy";

  private List<Cluster> models;

  /** Class name of the models, written to the stream so they can be re-instantiated on read. */
  private String modelClass;

  private ClusteringPolicy policy;

  /**
   * The public constructor accepts a list of clusters to become the models
   *
   * @param models a non-empty {@code List<Cluster>}; the class of the first element is recorded
   *               as the model class
   * @param policy a ClusteringPolicy
   * @throws IllegalArgumentException if {@code models} is null or empty
   */
  public ClusterClassifier(List<Cluster> models, ClusteringPolicy policy) {
    if (models == null || models.isEmpty()) {
      throw new IllegalArgumentException("At least one cluster model is required");
    }
    this.models = models;
    modelClass = models.get(0).getClass().getName();
    this.policy = policy;
  }

  // needed for serialization/De-serialization
  public ClusterClassifier() {
  }

  // only used by MR ClusterIterator
  protected ClusterClassifier(ClusteringPolicy policy) {
    this.policy = policy;
  }

  /** Delegates the classification of the instance to the clustering policy. */
  @Override
  public Vector classify(Vector instance) {
    return policy.classify(instance, this);
  }

  /**
   * Returns the pdf of the first model divided by the sum of the pdfs of both models.
   *
   * @throws IllegalStateException if the classifier does not hold exactly two models
   */
  @Override
  public double classifyScalar(Vector instance) {
    if (models.size() == 2) {
      VectorWritable wrapped = new VectorWritable(instance);
      double pdf0 = models.get(0).pdf(wrapped);
      double pdf1 = models.get(1).pdf(wrapped);
      return pdf0 / (pdf0 + pdf1);
    }
    throw new IllegalStateException(
        "classifyScalar requires exactly 2 models but the classifier has " + models.size());
  }

  @Override
  public int numCategories() {
    return models.size();
  }

  /** Writes the model count, the model class name, the policy and then every model. */
  @Override
  public void write(DataOutput out) throws IOException {
    out.writeInt(models.size());
    out.writeUTF(modelClass);
    new ClusteringPolicyWritable(policy).write(out);
    for (Cluster cluster : models) {
      cluster.write(out);
    }
  }

  /** Reads the state in the order used by {@link #write(DataOutput)}, replacing policy and models. */
  @Override
  public void readFields(DataInput in) throws IOException {
    int size = in.readInt();
    modelClass = in.readUTF();
    models = new ArrayList<>();
    ClusteringPolicyWritable clusteringPolicyWritable = new ClusteringPolicyWritable();
    clusteringPolicyWritable.readFields(in);
    policy = clusteringPolicyWritable.getValue();
    for (int i = 0; i < size; i++) {
      Cluster element = ClassUtils.instantiateAs(modelClass, Cluster.class);
      element.readFields(in);
      models.add(element);
    }
  }

  @Override
  public void train(int actual, Vector instance) {
    models.get(actual).observe(new VectorWritable(instance));
  }

  /**
   * Train the models given an additional weight. Unique to ClusterClassifier
   *
   * @param actual the int index of a model
   * @param data a data Vector
   * @param weight a double weighting factor
   */
  public void train(int actual, Vector data, double weight) {
    models.get(actual).observe(new VectorWritable(data), weight);
  }

  // trackingKey and groupKey are not used; the model at index 'actual' simply observes the instance
  @Override
  public void train(long trackingKey, String groupKey, int actual, Vector instance) {
    models.get(actual).observe(new VectorWritable(instance));
  }

  // trackingKey is not used; the model at index 'actual' simply observes the instance
  @Override
  public void train(long trackingKey, int actual, Vector instance) {
    models.get(actual).observe(new VectorWritable(instance));
  }

  @Override
  public void close() {
    policy.close(this);
  }

  public List<Cluster> getModels() {
    return models;
  }

  public ClusteringPolicy getPolicy() {
    return policy;
  }

  /**
   * Writes the policy file and one {@code part-NNNNN} sequence file per model into the directory.
   *
   * @param path the directory to write to
   */
  public void writeToSeqFiles(Path path) throws IOException {
    writePolicy(policy, path);
    Configuration config = new Configuration();
    FileSystem fs = FileSystem.get(path.toUri(), config);
    ClusterWritable cw = new ClusterWritable();
    for (int i = 0; i < models.size(); i++) {
      try (SequenceFile.Writer writer = new SequenceFile.Writer(fs, config,
          new Path(path, "part-" + String.format(Locale.ENGLISH, "%05d", i)), IntWritable.class,
          ClusterWritable.class)) {
        Cluster cluster = models.get(i);
        cw.setValue(cluster);
        Writable key = new IntWritable(i);
        writer.append(key, cw);
      }
    }
  }

  /**
   * Replaces the models and the policy of this classifier with the ones stored in the directory.
   *
   * @param conf the configuration handed to every cluster read
   * @param path the directory holding the model files and the policy file
   * @throws IllegalStateException if the directory contains no clusters
   */
  public void readFromSeqFiles(Configuration conf, Path path) throws IOException {
    // NOTE(review): the files are read with a fresh Configuration, 'conf' is only used to
    // configure the clusters - confirm whether that is intended.
    Configuration config = new Configuration();
    List<Cluster> clusters = new ArrayList<>();
    for (ClusterWritable cw : new SequenceFileDirValueIterable<ClusterWritable>(path, PathType.LIST,
        PathFilters.logsCRCFilter(), config)) {
      Cluster cluster = cw.getValue();
      cluster.configure(conf);
      clusters.add(cluster);
    }
    // Report the offending path instead of failing with an IndexOutOfBoundsException below.
    if (clusters.isEmpty()) {
      throw new IllegalStateException("No clusters found in " + path);
    }
    this.models = clusters;
    modelClass = models.get(0).getClass().getName();
    this.policy = readPolicy(path);
  }

  /**
   * Reads the clustering policy from the policy file in the given directory.
   *
   * @param path the directory containing the policy file
   * @return the policy stored in the first record of the file
   * @throws IOException if the file cannot be read or contains no record
   */
  public static ClusteringPolicy readPolicy(Path path) throws IOException {
    Path policyPath = new Path(path, POLICY_FILE_NAME);
    Configuration config = new Configuration();
    FileSystem fs = FileSystem.get(policyPath.toUri(), config);
    SequenceFile.Reader reader = new SequenceFile.Reader(fs, policyPath, config);
    try {
      Text key = new Text();
      ClusteringPolicyWritable cpw = new ClusteringPolicyWritable();
      if (!reader.next(key, cpw)) {
        throw new IOException("No clustering policy found in " + policyPath);
      }
      return cpw.getValue();
    } finally {
      // Close the reader even when reading fails; errors while closing are swallowed as before.
      Closeables.close(reader, true);
    }
  }

  /**
   * Writes the clustering policy as a single record to the policy file in the given directory.
   *
   * @param policy the policy to store
   * @param path the directory receiving the policy file
   */
  public static void writePolicy(ClusteringPolicy policy, Path path) throws IOException {
    Path policyPath = new Path(path, POLICY_FILE_NAME);
    Configuration config = new Configuration();
    FileSystem fs = FileSystem.get(policyPath.toUri(), config);
    // try-with-resources closes the writer even when append fails
    try (SequenceFile.Writer writer = new SequenceFile.Writer(fs, config, policyPath, Text.class,
        ClusteringPolicyWritable.class)) {
      writer.append(new Text(), new ClusteringPolicyWritable(policy));
    }
  }
}

http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/clustering/classify/WeightedPropertyVectorWritable.java
----------------------------------------------------------------------
diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/clustering/classify/WeightedPropertyVectorWritable.java b/community/mahout-mr/src/main/java/org/apache/mahout/clustering/classify/WeightedPropertyVectorWritable.java
new file mode 100644
index 0000000..567659b --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/clustering/classify/WeightedPropertyVectorWritable.java @@ -0,0 +1,95 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.clustering.classify; + +import org.apache.hadoop.io.Text; +import org.apache.mahout.clustering.AbstractCluster; +import org.apache.mahout.math.Vector; + +import java.io.DataInput; +import java.io.DataOutput; +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +public class WeightedPropertyVectorWritable extends WeightedVectorWritable { + + private Map<Text, Text> properties; + + public WeightedPropertyVectorWritable() { + } + + public WeightedPropertyVectorWritable(Map<Text, Text> properties) { + this.properties = properties; + } + + public WeightedPropertyVectorWritable(double weight, Vector vector, Map<Text, Text> properties) { + super(weight, vector); + this.properties = properties; + } + + public Map<Text, Text> getProperties() { + return properties; + } + + public void setProperties(Map<Text, Text> properties) { + this.properties = properties; + } + + @Override + public void readFields(DataInput in) throws IOException { + 
super.readFields(in); + int size = in.readInt(); + if (size > 0) { + properties = new HashMap<>(); + for (int i = 0; i < size; i++) { + Text key = new Text(in.readUTF()); + Text val = new Text(in.readUTF()); + properties.put(key, val); + } + } + } + + @Override + public void write(DataOutput out) throws IOException { + super.write(out); + out.writeInt(properties != null ? properties.size() : 0); + if (properties != null) { + for (Map.Entry<Text, Text> entry : properties.entrySet()) { + out.writeUTF(entry.getKey().toString()); + out.writeUTF(entry.getValue().toString()); + } + } + } + + @Override + public String toString() { + Vector vector = getVector(); + StringBuilder bldr = new StringBuilder("wt: ").append(getWeight()).append(' '); + if (properties != null && !properties.isEmpty()) { + for (Map.Entry<Text, Text> entry : properties.entrySet()) { + bldr.append(entry.getKey().toString()).append(": ").append(entry.getValue().toString()).append(' '); + } + } + bldr.append(" vec: ").append(vector == null ? "null" : AbstractCluster.formatVector(vector, null)); + return bldr.toString(); + } + + +} + http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/clustering/classify/WeightedVectorWritable.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/clustering/classify/WeightedVectorWritable.java b/community/mahout-mr/src/main/java/org/apache/mahout/clustering/classify/WeightedVectorWritable.java new file mode 100644 index 0000000..510dd39 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/clustering/classify/WeightedVectorWritable.java @@ -0,0 +1,72 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.clustering.classify; + +import java.io.DataInput; +import java.io.DataOutput; +import java.io.IOException; + +import org.apache.hadoop.io.Writable; +import org.apache.mahout.clustering.AbstractCluster; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.VectorWritable; + +public class WeightedVectorWritable implements Writable { + + private final VectorWritable vectorWritable = new VectorWritable(); + private double weight; + + public WeightedVectorWritable() { + } + + public WeightedVectorWritable(double weight, Vector vector) { + this.vectorWritable.set(vector); + this.weight = weight; + } + + public Vector getVector() { + return vectorWritable.get(); + } + + public void setVector(Vector vector) { + vectorWritable.set(vector); + } + + public double getWeight() { + return weight; + } + + @Override + public void readFields(DataInput in) throws IOException { + vectorWritable.readFields(in); + weight = in.readDouble(); + } + + @Override + public void write(DataOutput out) throws IOException { + vectorWritable.write(out); + out.writeDouble(weight); + } + + @Override + public String toString() { + Vector vector = vectorWritable.get(); + return weight + ": " + (vector == null ? 
"null" : AbstractCluster.formatVector(vector, null)); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansClusterer.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansClusterer.java b/community/mahout-mr/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansClusterer.java new file mode 100644 index 0000000..ff02a4c --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansClusterer.java @@ -0,0 +1,59 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.clustering.fuzzykmeans; + +import java.util.Collection; +import java.util.List; + +import org.apache.mahout.math.DenseVector; +import org.apache.mahout.math.Vector; + +public class FuzzyKMeansClusterer { + + private static final double MINIMAL_VALUE = 0.0000000001; + + private double m = 2.0; // default value + + public Vector computePi(Collection<SoftCluster> clusters, List<Double> clusterDistanceList) { + Vector pi = new DenseVector(clusters.size()); + for (int i = 0; i < clusters.size(); i++) { + double probWeight = computeProbWeight(clusterDistanceList.get(i), clusterDistanceList); + pi.set(i, probWeight); + } + return pi; + } + + /** Computes the probability of a point belonging to a cluster */ + public double computeProbWeight(double clusterDistance, Iterable<Double> clusterDistanceList) { + if (clusterDistance == 0) { + clusterDistance = MINIMAL_VALUE; + } + double denom = 0.0; + for (double eachCDist : clusterDistanceList) { + if (eachCDist == 0.0) { + eachCDist = MINIMAL_VALUE; + } + denom += Math.pow(clusterDistance / eachCDist, 2.0 / (m - 1)); + } + return 1.0 / denom; + } + + public void setM(double m) { + this.m = m; + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansDriver.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansDriver.java b/community/mahout-mr/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansDriver.java new file mode 100644 index 0000000..98eb944 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansDriver.java @@ -0,0 +1,324 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.clustering.fuzzykmeans; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.util.ToolRunner; +import org.apache.mahout.clustering.Cluster; +import org.apache.mahout.clustering.classify.ClusterClassificationDriver; +import org.apache.mahout.clustering.classify.ClusterClassifier; +import org.apache.mahout.clustering.iterator.ClusterIterator; +import org.apache.mahout.clustering.iterator.ClusteringPolicy; +import org.apache.mahout.clustering.iterator.FuzzyKMeansClusteringPolicy; +import org.apache.mahout.clustering.kmeans.RandomSeedGenerator; +import org.apache.mahout.clustering.topdown.PathDirectory; +import org.apache.mahout.common.AbstractJob; +import org.apache.mahout.common.ClassUtils; +import org.apache.mahout.common.HadoopUtil; +import org.apache.mahout.common.commandline.DefaultOptionCreator; +import org.apache.mahout.common.distance.DistanceMeasure; +import org.apache.mahout.common.distance.SquaredEuclideanDistanceMeasure; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class FuzzyKMeansDriver extends AbstractJob { + + public static final String M_OPTION = "m"; + + private static 
final Logger log = LoggerFactory.getLogger(FuzzyKMeansDriver.class); + + public static void main(String[] args) throws Exception { + ToolRunner.run(new Configuration(), new FuzzyKMeansDriver(), args); + } + + @Override + public int run(String[] args) throws Exception { + + addInputOption(); + addOutputOption(); + addOption(DefaultOptionCreator.distanceMeasureOption().create()); + addOption(DefaultOptionCreator.clustersInOption() + .withDescription("The input centroids, as Vectors. Must be a SequenceFile of Writable, Cluster/Canopy. " + + "If k is also specified, then a random set of vectors will be selected" + + " and written out to this path first") + .create()); + addOption(DefaultOptionCreator.numClustersOption() + .withDescription("The k in k-Means. If specified, then a random selection of k Vectors will be chosen" + + " as the Centroid and written to the clusters input path.").create()); + addOption(DefaultOptionCreator.convergenceOption().create()); + addOption(DefaultOptionCreator.maxIterationsOption().create()); + addOption(DefaultOptionCreator.overwriteOption().create()); + addOption(M_OPTION, M_OPTION, "coefficient normalization factor, must be greater than 1", true); + addOption(DefaultOptionCreator.clusteringOption().create()); + addOption(DefaultOptionCreator.emitMostLikelyOption().create()); + addOption(DefaultOptionCreator.thresholdOption().create()); + addOption(DefaultOptionCreator.methodOption().create()); + addOption(DefaultOptionCreator.useSetRandomSeedOption().create()); + + if (parseArguments(args) == null) { + return -1; + } + + Path input = getInputPath(); + Path clusters = new Path(getOption(DefaultOptionCreator.CLUSTERS_IN_OPTION)); + Path output = getOutputPath(); + String measureClass = getOption(DefaultOptionCreator.DISTANCE_MEASURE_OPTION); + if (measureClass == null) { + measureClass = SquaredEuclideanDistanceMeasure.class.getName(); + } + double convergenceDelta = 
Double.parseDouble(getOption(DefaultOptionCreator.CONVERGENCE_DELTA_OPTION)); + float fuzziness = Float.parseFloat(getOption(M_OPTION)); + + int maxIterations = Integer.parseInt(getOption(DefaultOptionCreator.MAX_ITERATIONS_OPTION)); + if (hasOption(DefaultOptionCreator.OVERWRITE_OPTION)) { + HadoopUtil.delete(getConf(), output); + } + boolean emitMostLikely = Boolean.parseBoolean(getOption(DefaultOptionCreator.EMIT_MOST_LIKELY_OPTION)); + double threshold = Double.parseDouble(getOption(DefaultOptionCreator.THRESHOLD_OPTION)); + DistanceMeasure measure = ClassUtils.instantiateAs(measureClass, DistanceMeasure.class); + + if (hasOption(DefaultOptionCreator.NUM_CLUSTERS_OPTION)) { + int numClusters = Integer.parseInt(getOption(DefaultOptionCreator.NUM_CLUSTERS_OPTION)); + + Long seed = null; + if (hasOption(DefaultOptionCreator.RANDOM_SEED)) { + seed = Long.parseLong(getOption(DefaultOptionCreator.RANDOM_SEED)); + } + + clusters = RandomSeedGenerator.buildRandom(getConf(), input, clusters, numClusters, measure, seed); + } + + boolean runClustering = hasOption(DefaultOptionCreator.CLUSTERING_OPTION); + boolean runSequential = getOption(DefaultOptionCreator.METHOD_OPTION).equalsIgnoreCase( + DefaultOptionCreator.SEQUENTIAL_METHOD); + + run(getConf(), + input, + clusters, + output, + convergenceDelta, + maxIterations, + fuzziness, + runClustering, + emitMostLikely, + threshold, + runSequential); + return 0; + } + + /** + * Iterate over the input vectors to produce clusters and, if requested, use the + * results of the final iteration to cluster the input vectors. 
+ * + * @param input + * the directory pathname for input points + * @param clustersIn + * the directory pathname for initial & computed clusters + * @param output + * the directory pathname for output points + * @param convergenceDelta +* the convergence delta value + * @param maxIterations +* the maximum number of iterations + * @param m +* the fuzzification factor, see +* http://en.wikipedia.org/wiki/Data_clustering#Fuzzy_c-means_clustering + * @param runClustering +* true if points are to be clustered after iterations complete + * @param emitMostLikely +* a boolean if true emit only most likely cluster for each point + * @param threshold +* a double threshold value emits all clusters having greater pdf (emitMostLikely = false) + * @param runSequential if true run in sequential execution mode + */ + public static void run(Path input, + Path clustersIn, + Path output, + double convergenceDelta, + int maxIterations, + float m, + boolean runClustering, + boolean emitMostLikely, + double threshold, + boolean runSequential) throws IOException, ClassNotFoundException, InterruptedException { + Configuration conf = new Configuration(); + Path clustersOut = buildClusters(conf, + input, + clustersIn, + output, + convergenceDelta, + maxIterations, + m, + runSequential); + if (runClustering) { + log.info("Clustering "); + clusterData(conf, input, + clustersOut, + output, + convergenceDelta, + m, + emitMostLikely, + threshold, + runSequential); + } + } + + /** + * Iterate over the input vectors to produce clusters and, if requested, use the + * results of the final iteration to cluster the input vectors. 
+ * @param input + * the directory pathname for input points + * @param clustersIn + * the directory pathname for initial & computed clusters + * @param output + * the directory pathname for output points + * @param convergenceDelta +* the convergence delta value + * @param maxIterations +* the maximum number of iterations + * @param m +* the fuzzification factor, see +* http://en.wikipedia.org/wiki/Data_clustering#Fuzzy_c-means_clustering + * @param runClustering +* true if points are to be clustered after iterations complete + * @param emitMostLikely +* a boolean if true emit only most likely cluster for each point + * @param threshold +* a double threshold value emits all clusters having greater pdf (emitMostLikely = false) + * @param runSequential if true run in sequential execution mode + */ + public static void run(Configuration conf, + Path input, + Path clustersIn, + Path output, + double convergenceDelta, + int maxIterations, + float m, + boolean runClustering, + boolean emitMostLikely, + double threshold, + boolean runSequential) + throws IOException, ClassNotFoundException, InterruptedException { + Path clustersOut = + buildClusters(conf, input, clustersIn, output, convergenceDelta, maxIterations, m, runSequential); + if (runClustering) { + log.info("Clustering"); + clusterData(conf, + input, + clustersOut, + output, + convergenceDelta, + m, + emitMostLikely, + threshold, + runSequential); + } + } + + /** + * Iterate over the input vectors to produce cluster directories for each iteration + * + * @param input + * the directory pathname for input points + * @param clustersIn + * the file pathname for initial cluster centers + * @param output + * the directory pathname for output points + * @param convergenceDelta + * the convergence delta value + * @param maxIterations + * the maximum number of iterations + * @param m + * the fuzzification factor, see + * http://en.wikipedia.org/wiki/Data_clustering#Fuzzy_c-means_clustering + * @param runSequential if 
true run in sequential execution mode + * + * @return the Path of the final clusters directory + */ + public static Path buildClusters(Configuration conf, + Path input, + Path clustersIn, + Path output, + double convergenceDelta, + int maxIterations, + float m, + boolean runSequential) + throws IOException, InterruptedException, ClassNotFoundException { + + List<Cluster> clusters = new ArrayList<>(); + FuzzyKMeansUtil.configureWithClusterInfo(conf, clustersIn, clusters); + + if (conf == null) { + conf = new Configuration(); + } + + if (clusters.isEmpty()) { + throw new IllegalStateException("No input clusters found in " + clustersIn + ". Check your -c argument."); + } + + Path priorClustersPath = new Path(output, Cluster.INITIAL_CLUSTERS_DIR); + ClusteringPolicy policy = new FuzzyKMeansClusteringPolicy(m, convergenceDelta); + ClusterClassifier prior = new ClusterClassifier(clusters, policy); + prior.writeToSeqFiles(priorClustersPath); + + if (runSequential) { + ClusterIterator.iterateSeq(conf, input, priorClustersPath, output, maxIterations); + } else { + ClusterIterator.iterateMR(conf, input, priorClustersPath, output, maxIterations); + } + return output; + } + + /** + * Run the job using supplied arguments + * + * @param input + * the directory pathname for input points + * @param clustersIn + * the directory pathname for input clusters + * @param output + * the directory pathname for output points + * @param convergenceDelta +* the convergence delta value + * @param emitMostLikely +* a boolean if true emit only most likely cluster for each point + * @param threshold +* a double threshold value emits all clusters having greater pdf (emitMostLikely = false) + * @param runSequential if true run in sequential execution mode + */ + public static void clusterData(Configuration conf, + Path input, + Path clustersIn, + Path output, + double convergenceDelta, + float m, + boolean emitMostLikely, + double threshold, + boolean runSequential) + throws IOException, 
ClassNotFoundException, InterruptedException { + + ClusterClassifier.writePolicy(new FuzzyKMeansClusteringPolicy(m, convergenceDelta), clustersIn); + ClusterClassificationDriver.run(conf, input, output, new Path(output, PathDirectory.CLUSTERED_POINTS_DIRECTORY), + threshold, emitMostLikely, runSequential); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansUtil.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansUtil.java b/community/mahout-mr/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansUtil.java new file mode 100644 index 0000000..25621bb --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/clustering/fuzzykmeans/FuzzyKMeansUtil.java @@ -0,0 +1,76 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.clustering.fuzzykmeans; + +import java.util.List; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.Writable; +import org.apache.mahout.clustering.Cluster; +import org.apache.mahout.clustering.canopy.Canopy; +import org.apache.mahout.clustering.iterator.ClusterWritable; +import org.apache.mahout.clustering.kmeans.Kluster; +import org.apache.mahout.common.iterator.sequencefile.PathFilters; +import org.apache.mahout.common.iterator.sequencefile.PathType; +import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirValueIterable; + +final class FuzzyKMeansUtil { + + private FuzzyKMeansUtil() {} + + /** + * Create a list of SoftClusters from whatever type is passed in as the prior + * + * @param conf + * the Configuration + * @param clusterPath + * the path to the prior Clusters + * @param clusters + * a List<Cluster> to put values into + */ + public static void configureWithClusterInfo(Configuration conf, Path clusterPath, List<Cluster> clusters) { + for (Writable value : new SequenceFileDirValueIterable<>(clusterPath, PathType.LIST, + PathFilters.partFilter(), conf)) { + Class<? 
extends Writable> valueClass = value.getClass(); + + if (valueClass.equals(ClusterWritable.class)) { + ClusterWritable clusterWritable = (ClusterWritable) value; + value = clusterWritable.getValue(); + valueClass = value.getClass(); + } + + if (valueClass.equals(Kluster.class)) { + // get the cluster info + Kluster cluster = (Kluster) value; + clusters.add(new SoftCluster(cluster.getCenter(), cluster.getId(), cluster.getMeasure())); + } else if (valueClass.equals(SoftCluster.class)) { + // get the cluster info + clusters.add((SoftCluster) value); + } else if (valueClass.equals(Canopy.class)) { + // get the cluster info + Canopy canopy = (Canopy) value; + clusters.add(new SoftCluster(canopy.getCenter(), canopy.getId(), canopy.getMeasure())); + } else { + throw new IllegalStateException("Bad value class: " + valueClass); + } + } + + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/clustering/fuzzykmeans/SoftCluster.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/clustering/fuzzykmeans/SoftCluster.java b/community/mahout-mr/src/main/java/org/apache/mahout/clustering/fuzzykmeans/SoftCluster.java new file mode 100644 index 0000000..52fd764 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/clustering/fuzzykmeans/SoftCluster.java @@ -0,0 +1,60 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.clustering.fuzzykmeans; + +import org.apache.mahout.clustering.kmeans.Kluster; +import org.apache.mahout.common.distance.DistanceMeasure; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.VectorWritable; + +public class SoftCluster extends Kluster { + + // For Writable + public SoftCluster() {} + + /** + * Construct a new SoftCluster with the given point as its center + * + * @param center + * the center point + * @param measure + * the DistanceMeasure + */ + public SoftCluster(Vector center, int clusterId, DistanceMeasure measure) { + super(center, clusterId, measure); + } + + @Override + public String asFormatString() { + return this.getIdentifier() + ": " + + this.computeCentroid().asFormatString(); + } + + @Override + public String getIdentifier() { + return (isConverged() ? "SV-" : "SC-") + getId(); + } + + @Override + public double pdf(VectorWritable vw) { + // SoftCluster pdf cannot be calculated out of context. See + // FuzzyKMeansClusterer + throw new UnsupportedOperationException( + "SoftCluster pdf cannot be calculated out of context. 
See FuzzyKMeansClusterer"); + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/clustering/iterator/AbstractClusteringPolicy.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/clustering/iterator/AbstractClusteringPolicy.java b/community/mahout-mr/src/main/java/org/apache/mahout/clustering/iterator/AbstractClusteringPolicy.java new file mode 100644 index 0000000..07cc7e3 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/clustering/iterator/AbstractClusteringPolicy.java @@ -0,0 +1,72 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.mahout.clustering.iterator; + +import java.io.DataInput; +import java.io.DataOutput; +import java.io.IOException; +import java.util.List; + +import org.apache.mahout.clustering.Cluster; +import org.apache.mahout.clustering.classify.ClusterClassifier; +import org.apache.mahout.math.DenseVector; +import org.apache.mahout.math.SequentialAccessSparseVector; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.VectorWritable; +import org.apache.mahout.math.function.TimesFunction; + +public abstract class AbstractClusteringPolicy implements ClusteringPolicy { + + @Override + public abstract void write(DataOutput out) throws IOException; + + @Override + public abstract void readFields(DataInput in) throws IOException; + + @Override + public Vector select(Vector probabilities) { + int maxValueIndex = probabilities.maxValueIndex(); + Vector weights = new SequentialAccessSparseVector(probabilities.size()); + weights.set(maxValueIndex, 1.0); + return weights; + } + + @Override + public void update(ClusterClassifier posterior) { + // nothing to do in general here + } + + @Override + public Vector classify(Vector data, ClusterClassifier prior) { + List<Cluster> models = prior.getModels(); + int i = 0; + Vector pdfs = new DenseVector(models.size()); + for (Cluster model : models) { + pdfs.set(i++, model.pdf(new VectorWritable(data))); + } + return pdfs.assign(new TimesFunction(), 1.0 / pdfs.zSum()); + } + + @Override + public void close(ClusterClassifier posterior) { + for (Cluster cluster : posterior.getModels()) { + cluster.computeParameters(); + } + + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/clustering/iterator/CIMapper.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/clustering/iterator/CIMapper.java 
b/community/mahout-mr/src/main/java/org/apache/mahout/clustering/iterator/CIMapper.java new file mode 100644 index 0000000..fb2db49 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/clustering/iterator/CIMapper.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.clustering.iterator; + +import java.io.IOException; +import java.util.List; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.io.WritableComparable; +import org.apache.hadoop.mapreduce.Mapper; +import org.apache.mahout.clustering.Cluster; +import org.apache.mahout.clustering.classify.ClusterClassifier; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.Vector.Element; +import org.apache.mahout.math.VectorWritable; + +public class CIMapper extends Mapper<WritableComparable<?>,VectorWritable,IntWritable,ClusterWritable> { + + private ClusterClassifier classifier; + private ClusteringPolicy policy; + + @Override + protected void setup(Context context) throws IOException, InterruptedException { + Configuration conf = context.getConfiguration(); + String priorClustersPath = conf.get(ClusterIterator.PRIOR_PATH_KEY); + classifier = new ClusterClassifier(); + classifier.readFromSeqFiles(conf, new Path(priorClustersPath)); + policy = classifier.getPolicy(); + policy.update(classifier); + super.setup(context); + } + + @Override + protected void map(WritableComparable<?> key, VectorWritable value, Context context) throws IOException, + InterruptedException { + Vector probabilities = classifier.classify(value.get()); + Vector selections = policy.select(probabilities); + for (Element el : selections.nonZeroes()) { + classifier.train(el.index(), value.get(), el.get()); + } + } + + @Override + protected void cleanup(Context context) throws IOException, InterruptedException { + List<Cluster> clusters = classifier.getModels(); + ClusterWritable cw = new ClusterWritable(); + for (int index = 0; index < clusters.size(); index++) { + cw.setValue(clusters.get(index)); + context.write(new IntWritable(index), cw); + } + super.cleanup(context); + } + +} 
http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/clustering/iterator/CIReducer.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/clustering/iterator/CIReducer.java b/community/mahout-mr/src/main/java/org/apache/mahout/clustering/iterator/CIReducer.java new file mode 100644 index 0000000..ca63b0f --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/clustering/iterator/CIReducer.java @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.mahout.clustering.iterator; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.mapreduce.Reducer; +import org.apache.mahout.clustering.Cluster; +import org.apache.mahout.clustering.classify.ClusterClassifier; + +public class CIReducer extends Reducer<IntWritable,ClusterWritable,IntWritable,ClusterWritable> { + + private ClusterClassifier classifier; + private ClusteringPolicy policy; + + @Override + protected void reduce(IntWritable key, Iterable<ClusterWritable> values, Context context) throws IOException, + InterruptedException { + Iterator<ClusterWritable> iter = values.iterator(); + Cluster first = iter.next().getValue(); // there must always be at least one + while (iter.hasNext()) { + Cluster cluster = iter.next().getValue(); + first.observe(cluster); + } + List<Cluster> models = new ArrayList<>(); + models.add(first); + classifier = new ClusterClassifier(models, policy); + classifier.close(); + context.write(key, new ClusterWritable(first)); + } + + @Override + protected void setup(Context context) throws IOException, InterruptedException { + Configuration conf = context.getConfiguration(); + String priorClustersPath = conf.get(ClusterIterator.PRIOR_PATH_KEY); + classifier = new ClusterClassifier(); + classifier.readFromSeqFiles(conf, new Path(priorClustersPath)); + policy = classifier.getPolicy(); + policy.update(classifier); + super.setup(context); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/clustering/iterator/CanopyClusteringPolicy.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/clustering/iterator/CanopyClusteringPolicy.java 
b/community/mahout-mr/src/main/java/org/apache/mahout/clustering/iterator/CanopyClusteringPolicy.java new file mode 100644 index 0000000..c9a0940 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/clustering/iterator/CanopyClusteringPolicy.java @@ -0,0 +1,52 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.mahout.clustering.iterator; + +import java.io.DataInput; +import java.io.DataOutput; +import java.io.IOException; + +import org.apache.mahout.math.SequentialAccessSparseVector; +import org.apache.mahout.math.Vector; + +@Deprecated +public class CanopyClusteringPolicy extends AbstractClusteringPolicy { + + private double t1; + private double t2; + + @Override + public Vector select(Vector probabilities) { + int maxValueIndex = probabilities.maxValueIndex(); + Vector weights = new SequentialAccessSparseVector(probabilities.size()); + weights.set(maxValueIndex, 1.0); + return weights; + } + + @Override + public void write(DataOutput out) throws IOException { + out.writeDouble(t1); + out.writeDouble(t2); + } + + @Override + public void readFields(DataInput in) throws IOException { + this.t1 = in.readDouble(); + this.t2 = in.readDouble(); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/clustering/iterator/ClusterIterator.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/clustering/iterator/ClusterIterator.java b/community/mahout-mr/src/main/java/org/apache/mahout/clustering/iterator/ClusterIterator.java new file mode 100644 index 0000000..516177f --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/clustering/iterator/ClusterIterator.java @@ -0,0 +1,219 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.mahout.clustering.iterator; + +import java.io.IOException; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.fs.FileStatus; +import org.apache.hadoop.fs.FileSystem; +import org.apache.hadoop.fs.Path; +import org.apache.hadoop.io.IntWritable; +import org.apache.hadoop.mapreduce.Job; +import org.apache.hadoop.mapreduce.lib.input.FileInputFormat; +import org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat; +import org.apache.hadoop.mapreduce.lib.output.FileOutputFormat; +import org.apache.hadoop.mapreduce.lib.output.SequenceFileOutputFormat; +import org.apache.mahout.clustering.Cluster; +import org.apache.mahout.clustering.classify.ClusterClassifier; +import org.apache.mahout.common.iterator.sequencefile.PathFilters; +import org.apache.mahout.common.iterator.sequencefile.PathType; +import org.apache.mahout.common.iterator.sequencefile.SequenceFileDirValueIterable; +import org.apache.mahout.common.iterator.sequencefile.SequenceFileValueIterator; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.VectorWritable; + +import com.google.common.io.Closeables; + +/** + * This is a clustering iterator which works with a set of Vector data and a prior ClusterClassifier which has been + * initialized with a set of models. Its implementation is algorithm-neutral and works for any iterative clustering + * algorithm (currently k-means and fuzzy-k-means) that processes all the input vectors in each iteration. 
+ * The cluster classifier is configured with a ClusteringPolicy to select the desired clustering algorithm. + */ +public final class ClusterIterator { + + public static final String PRIOR_PATH_KEY = "org.apache.mahout.clustering.prior.path"; + + private ClusterIterator() { + } + + /** + * Iterate over data using a prior-trained ClusterClassifier, for a number of iterations + * + * @param data + * a {@code List<Vector>} of input vectors + * @param classifier + * a prior ClusterClassifier + * @param numIterations + * the int number of iterations to perform + * + * @return the posterior ClusterClassifier + */ + public static ClusterClassifier iterate(Iterable<Vector> data, ClusterClassifier classifier, int numIterations) { + ClusteringPolicy policy = classifier.getPolicy(); + for (int iteration = 1; iteration <= numIterations; iteration++) { + for (Vector vector : data) { + // update the policy based upon the prior + policy.update(classifier); + // classification yields probabilities + Vector probabilities = classifier.classify(vector); + // policy selects weights for models given those probabilities + Vector weights = policy.select(probabilities); + // training causes all models to observe data + for (Vector.Element e : weights.nonZeroes()) { + int index = e.index(); + classifier.train(index, vector, weights.get(index)); + } + } + // compute the posterior models + classifier.close(); + } + return classifier; + } + + /** + * Iterate over data using a prior-trained ClusterClassifier, for a number of iterations using a sequential + * implementation + * + * @param conf + * the Configuration + * @param inPath + * a Path to input VectorWritables + * @param priorPath + * a Path to the prior classifier + * @param outPath + * a Path of output directory + * @param numIterations + * the int number of iterations to perform + */ + public static void iterateSeq(Configuration conf, Path inPath, Path priorPath, Path outPath, int numIterations) + throws IOException { + 
ClusterClassifier classifier = new ClusterClassifier(); + classifier.readFromSeqFiles(conf, priorPath); + Path clustersOut = null; + int iteration = 1; + while (iteration <= numIterations) { + for (VectorWritable vw : new SequenceFileDirValueIterable<VectorWritable>(inPath, PathType.LIST, + PathFilters.logsCRCFilter(), conf)) { + Vector vector = vw.get(); + // classification yields probabilities + Vector probabilities = classifier.classify(vector); + // policy selects weights for models given those probabilities + Vector weights = classifier.getPolicy().select(probabilities); + // training causes all models to observe data + for (Vector.Element e : weights.nonZeroes()) { + int index = e.index(); + classifier.train(index, vector, weights.get(index)); + } + } + // compute the posterior models + classifier.close(); + // update the policy + classifier.getPolicy().update(classifier); + // output the classifier + clustersOut = new Path(outPath, Cluster.CLUSTERS_DIR + iteration); + classifier.writeToSeqFiles(clustersOut); + FileSystem fs = FileSystem.get(outPath.toUri(), conf); + iteration++; + if (isConverged(clustersOut, conf, fs)) { + break; + } + } + Path finalClustersIn = new Path(outPath, Cluster.CLUSTERS_DIR + (iteration - 1) + Cluster.FINAL_ITERATION_SUFFIX); + FileSystem.get(clustersOut.toUri(), conf).rename(clustersOut, finalClustersIn); + } + + /** + * Iterate over data using a prior-trained ClusterClassifier, for a number of iterations using a mapreduce + * implementation + * + * @param conf + * the Configuration + * @param inPath + * a Path to input VectorWritables + * @param priorPath + * a Path to the prior classifier + * @param outPath + * a Path of output directory + * @param numIterations + * the int number of iterations to perform + */ + public static void iterateMR(Configuration conf, Path inPath, Path priorPath, Path outPath, int numIterations) + throws IOException, InterruptedException, ClassNotFoundException { + ClusteringPolicy policy = 
ClusterClassifier.readPolicy(priorPath); + Path clustersOut = null; + int iteration = 1; + while (iteration <= numIterations) { + conf.set(PRIOR_PATH_KEY, priorPath.toString()); + + String jobName = "Cluster Iterator running iteration " + iteration + " over priorPath: " + priorPath; + Job job = new Job(conf, jobName); + job.setMapOutputKeyClass(IntWritable.class); + job.setMapOutputValueClass(ClusterWritable.class); + job.setOutputKeyClass(IntWritable.class); + job.setOutputValueClass(ClusterWritable.class); + + job.setInputFormatClass(SequenceFileInputFormat.class); + job.setOutputFormatClass(SequenceFileOutputFormat.class); + job.setMapperClass(CIMapper.class); + job.setReducerClass(CIReducer.class); + + FileInputFormat.addInputPath(job, inPath); + clustersOut = new Path(outPath, Cluster.CLUSTERS_DIR + iteration); + priorPath = clustersOut; + FileOutputFormat.setOutputPath(job, clustersOut); + + job.setJarByClass(ClusterIterator.class); + if (!job.waitForCompletion(true)) { + throw new InterruptedException("Cluster Iteration " + iteration + " failed processing " + priorPath); + } + ClusterClassifier.writePolicy(policy, clustersOut); + FileSystem fs = FileSystem.get(outPath.toUri(), conf); + iteration++; + if (isConverged(clustersOut, conf, fs)) { + break; + } + } + Path finalClustersIn = new Path(outPath, Cluster.CLUSTERS_DIR + (iteration - 1) + Cluster.FINAL_ITERATION_SUFFIX); + FileSystem.get(clustersOut.toUri(), conf).rename(clustersOut, finalClustersIn); + } + + /** + * Return if all of the Clusters in the parts in the filePath have converged or not + * + * @param filePath + * the file path to the single file containing the clusters + * @return true if all Clusters are converged + * @throws IOException + * if there was an IO error + */ + private static boolean isConverged(Path filePath, Configuration conf, FileSystem fs) throws IOException { + for (FileStatus part : fs.listStatus(filePath, PathFilters.partFilter())) { + 
SequenceFileValueIterator<ClusterWritable> iterator = new SequenceFileValueIterator<>( + part.getPath(), true, conf); + while (iterator.hasNext()) { + ClusterWritable value = iterator.next(); + if (!value.getValue().isConverged()) { + Closeables.close(iterator, true); + return false; + } + } + } + return true; + } +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/clustering/iterator/ClusterWritable.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/clustering/iterator/ClusterWritable.java b/community/mahout-mr/src/main/java/org/apache/mahout/clustering/iterator/ClusterWritable.java new file mode 100644 index 0000000..855685f --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/clustering/iterator/ClusterWritable.java @@ -0,0 +1,56 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.mahout.clustering.iterator; + +import java.io.DataInput; +import java.io.DataOutput; +import java.io.IOException; + +import org.apache.hadoop.io.Writable; +import org.apache.mahout.classifier.sgd.PolymorphicWritable; +import org.apache.mahout.clustering.Cluster; + +public class ClusterWritable implements Writable { + + private Cluster value; + + public ClusterWritable(Cluster first) { + value = first; + } + + public ClusterWritable() { + } + + public Cluster getValue() { + return value; + } + + public void setValue(Cluster value) { + this.value = value; + } + + @Override + public void write(DataOutput out) throws IOException { + PolymorphicWritable.write(out, value); + } + + @Override + public void readFields(DataInput in) throws IOException { + value = PolymorphicWritable.read(in, Cluster.class); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/clustering/iterator/ClusteringPolicy.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/clustering/iterator/ClusteringPolicy.java b/community/mahout-mr/src/main/java/org/apache/mahout/clustering/iterator/ClusteringPolicy.java new file mode 100644 index 0000000..6e15838 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/clustering/iterator/ClusteringPolicy.java @@ -0,0 +1,66 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.mahout.clustering.iterator; + +import org.apache.hadoop.io.Writable; +import org.apache.mahout.clustering.classify.ClusterClassifier; +import org.apache.mahout.math.Vector; + +/** + * A ClusteringPolicy captures the semantics of assignment of points to clusters + * + */ +public interface ClusteringPolicy extends Writable { + + /** + * Classify the data vector given the classifier's models + * + * @param data + * a data Vector + * @param prior + * a prior ClusterClassifier + * @return a Vector of probabilities that the data is described by each of the + * models + */ + Vector classify(Vector data, ClusterClassifier prior); + + /** + * Return a vector of weights for each of the models given those probabilities + * + * @param probabilities + * a Vector of pdfs + * @return a Vector of weights + */ + Vector select(Vector probabilities); + + /** + * Update the policy with the given classifier + * + * @param posterior + * a ClusterClassifier + */ + void update(ClusterClassifier posterior); + + /** + * Close the policy using the classifier's models + * + * @param posterior + * a posterior ClusterClassifier + */ + void close(ClusterClassifier posterior); + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/clustering/iterator/ClusteringPolicyWritable.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/clustering/iterator/ClusteringPolicyWritable.java 
b/community/mahout-mr/src/main/java/org/apache/mahout/clustering/iterator/ClusteringPolicyWritable.java new file mode 100644 index 0000000..f69442d --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/clustering/iterator/ClusteringPolicyWritable.java @@ -0,0 +1,55 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.mahout.clustering.iterator; + +import java.io.DataInput; +import java.io.DataOutput; +import java.io.IOException; + +import org.apache.hadoop.io.Writable; +import org.apache.mahout.classifier.sgd.PolymorphicWritable; + +public class ClusteringPolicyWritable implements Writable { + + private ClusteringPolicy value; + + public ClusteringPolicyWritable(ClusteringPolicy policy) { + this.value = policy; + } + + public ClusteringPolicyWritable() { + } + + public ClusteringPolicy getValue() { + return value; + } + + public void setValue(ClusteringPolicy value) { + this.value = value; + } + + @Override + public void write(DataOutput out) throws IOException { + PolymorphicWritable.write(out, value); + } + + @Override + public void readFields(DataInput in) throws IOException { + value = PolymorphicWritable.read(in, ClusteringPolicy.class); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/clustering/iterator/DistanceMeasureCluster.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/clustering/iterator/DistanceMeasureCluster.java b/community/mahout-mr/src/main/java/org/apache/mahout/clustering/iterator/DistanceMeasureCluster.java new file mode 100644 index 0000000..f61aa27 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/clustering/iterator/DistanceMeasureCluster.java @@ -0,0 +1,91 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.clustering.iterator; + +import java.io.DataInput; +import java.io.DataOutput; +import java.io.IOException; + +import org.apache.hadoop.conf.Configuration; +import org.apache.mahout.clustering.AbstractCluster; +import org.apache.mahout.clustering.Model; +import org.apache.mahout.common.ClassUtils; +import org.apache.mahout.common.distance.DistanceMeasure; +import org.apache.mahout.math.Vector; +import org.apache.mahout.math.VectorWritable; + +public class DistanceMeasureCluster extends AbstractCluster { + + private DistanceMeasure measure; + + public DistanceMeasureCluster(Vector point, int id, DistanceMeasure measure) { + super(point, id); + this.measure = measure; + } + + public DistanceMeasureCluster() { + } + + @Override + public void configure(Configuration job) { + if (measure != null) { + measure.configure(job); + } + } + + @Override + public void readFields(DataInput in) throws IOException { + String dm = in.readUTF(); + this.measure = ClassUtils.instantiateAs(dm, DistanceMeasure.class); + super.readFields(in); + } + + @Override + public void write(DataOutput out) throws IOException { + out.writeUTF(measure.getClass().getName()); + super.write(out); + } + + @Override + public double pdf(VectorWritable vw) { + return 1 / (1 + measure.distance(vw.get(), getCenter())); + } + + @Override + public Model<VectorWritable> sampleFromPosterior() { + return new DistanceMeasureCluster(getCenter(), getId(), measure); + } + + public DistanceMeasure getMeasure() { + return measure; + } + + /** + * @param measure + * the 
measure to set + */ + public void setMeasure(DistanceMeasure measure) { + this.measure = measure; + } + + @Override + public String getIdentifier() { + return "DMC:" + getId(); + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/clustering/iterator/FuzzyKMeansClusteringPolicy.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/clustering/iterator/FuzzyKMeansClusteringPolicy.java b/community/mahout-mr/src/main/java/org/apache/mahout/clustering/iterator/FuzzyKMeansClusteringPolicy.java new file mode 100644 index 0000000..b4e41b6 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/clustering/iterator/FuzzyKMeansClusteringPolicy.java @@ -0,0 +1,90 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.mahout.clustering.iterator; + +import java.io.DataInput; +import java.io.DataOutput; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; + +import org.apache.mahout.clustering.Cluster; +import org.apache.mahout.clustering.classify.ClusterClassifier; +import org.apache.mahout.clustering.fuzzykmeans.FuzzyKMeansClusterer; +import org.apache.mahout.clustering.fuzzykmeans.SoftCluster; +import org.apache.mahout.math.Vector; + +/** + * This is a probability-weighted clustering policy, suitable for fuzzy k-means + * clustering + * + */ +public class FuzzyKMeansClusteringPolicy extends AbstractClusteringPolicy { + + private double m = 2; + private double convergenceDelta = 0.05; + + public FuzzyKMeansClusteringPolicy() { + } + + public FuzzyKMeansClusteringPolicy(double m, double convergenceDelta) { + this.m = m; + this.convergenceDelta = convergenceDelta; + } + + @Override + public Vector select(Vector probabilities) { + return probabilities; + } + + @Override + public Vector classify(Vector data, ClusterClassifier prior) { + Collection<SoftCluster> clusters = new ArrayList<>(); + List<Double> distances = new ArrayList<>(); + for (Cluster model : prior.getModels()) { + SoftCluster sc = (SoftCluster) model; + clusters.add(sc); + distances.add(sc.getMeasure().distance(data, sc.getCenter())); + } + FuzzyKMeansClusterer fuzzyKMeansClusterer = new FuzzyKMeansClusterer(); + fuzzyKMeansClusterer.setM(m); + return fuzzyKMeansClusterer.computePi(clusters, distances); + } + + @Override + public void write(DataOutput out) throws IOException { + out.writeDouble(m); + out.writeDouble(convergenceDelta); + } + + @Override + public void readFields(DataInput in) throws IOException { + this.m = in.readDouble(); + this.convergenceDelta = in.readDouble(); + } + + @Override + public void close(ClusterClassifier posterior) { + for (Cluster cluster : posterior.getModels()) { + 
((org.apache.mahout.clustering.kmeans.Kluster) cluster).calculateConvergence(convergenceDelta); + cluster.computeParameters(); + } + + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/clustering/iterator/KMeansClusteringPolicy.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/clustering/iterator/KMeansClusteringPolicy.java b/community/mahout-mr/src/main/java/org/apache/mahout/clustering/iterator/KMeansClusteringPolicy.java new file mode 100644 index 0000000..1cc9faf --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/clustering/iterator/KMeansClusteringPolicy.java @@ -0,0 +1,64 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.mahout.clustering.iterator; + +import java.io.DataInput; +import java.io.DataOutput; +import java.io.IOException; + +import org.apache.mahout.clustering.Cluster; +import org.apache.mahout.clustering.classify.ClusterClassifier; + +/** + * This is a simple maximum likelihood clustering policy, suitable for k-means + * clustering + * + */ +public class KMeansClusteringPolicy extends AbstractClusteringPolicy { + + public KMeansClusteringPolicy() { + } + + public KMeansClusteringPolicy(double convergenceDelta) { + this.convergenceDelta = convergenceDelta; + } + + private double convergenceDelta = 0.001; + + @Override + public void write(DataOutput out) throws IOException { + out.writeDouble(convergenceDelta); + } + + @Override + public void readFields(DataInput in) throws IOException { + this.convergenceDelta = in.readDouble(); + } + + @Override + public void close(ClusterClassifier posterior) { + boolean allConverged = true; + for (Cluster cluster : posterior.getModels()) { + org.apache.mahout.clustering.kmeans.Kluster kluster = (org.apache.mahout.clustering.kmeans.Kluster) cluster; + boolean converged = kluster.calculateConvergence(convergenceDelta); + allConverged = allConverged && converged; + cluster.computeParameters(); + } + + } + +} http://git-wip-us.apache.org/repos/asf/mahout/blob/5eda9e1f/community/mahout-mr/src/main/java/org/apache/mahout/clustering/kernel/IKernelProfile.java ---------------------------------------------------------------------- diff --git a/community/mahout-mr/src/main/java/org/apache/mahout/clustering/kernel/IKernelProfile.java b/community/mahout-mr/src/main/java/org/apache/mahout/clustering/kernel/IKernelProfile.java new file mode 100644 index 0000000..96c4082 --- /dev/null +++ b/community/mahout-mr/src/main/java/org/apache/mahout/clustering/kernel/IKernelProfile.java @@ -0,0 +1,27 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
/**
 * A kernel profile that can report the value of its derivative.
 */
public interface IKernelProfile {

  /**
   * Calculates the derivative value of the kernel.
   *
   * @param distance
   *          the distance at which the derivative is evaluated
   * @param h
   *          the kernel parameter the distance is evaluated against (presumably the kernel
   *          bandwidth; confirm with implementations)
   * @return the calculated derivative value of the kernel
   */
  double calculateDerivativeValue(double distance, double h);

}
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.mahout.clustering.kernel; + +public class TriangularKernelProfile implements IKernelProfile { + + @Override + public double calculateDerivativeValue(double distance, double h) { + return distance < h ? 1.0 : 0.0; + } + +}
