Author: pranjan
Date: Sun Feb 26 17:02:20 2012
New Revision: 1293874
URL: http://svn.apache.org/viewvc?rev=1293874&view=rev
Log:
MAHOUT-931, MAHOUT-929. Added emitMostLikely and threshold based outlier
removal capability in ClusterClassificationDriver.
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationDriver.java
mahout/trunk/core/src/test/java/org/apache/mahout/clustering/classify/ClusterClassificationDriverTest.java
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationDriver.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationDriver.java?rev=1293874&r1=1293873&r2=1293874&view=diff
==============================================================================
---
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationDriver.java
(original)
+++
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationDriver.java
Sun Feb 26 17:02:20 2012
@@ -46,6 +46,7 @@ import org.apache.mahout.common.iterator
import
org.apache.mahout.common.iterator.sequencefile.SequenceFileDirValueIterable;
import
org.apache.mahout.common.iterator.sequencefile.SequenceFileDirValueIterator;
import org.apache.mahout.math.Vector;
+import org.apache.mahout.math.Vector.Element;
import org.apache.mahout.math.VectorWritable;
/**
@@ -63,8 +64,10 @@ public class ClusterClassificationDriver
addInputOption();
addOutputOption();
addOption(DefaultOptionCreator.methodOption().create());
- addOption(DefaultOptionCreator.clustersInOption()
- .withDescription("The input centroids, as Vectors. Must be a
SequenceFile of Writable, Cluster/Canopy.")
+ addOption(DefaultOptionCreator
+ .clustersInOption()
+ .withDescription(
+ "The input centroids, as Vectors. Must be a SequenceFile of
Writable, Cluster/Canopy.")
.create());
if (parseArguments(args) == null) {
@@ -77,16 +80,19 @@ public class ClusterClassificationDriver
if (getConf() == null) {
setConf(new Configuration());
}
- Path clustersIn = new
Path(getOption(DefaultOptionCreator.CLUSTERS_IN_OPTION));
- boolean runSequential =
getOption(DefaultOptionCreator.METHOD_OPTION).equalsIgnoreCase(
- DefaultOptionCreator.SEQUENTIAL_METHOD);
+ Path clustersIn = new Path(
+ getOption(DefaultOptionCreator.CLUSTERS_IN_OPTION));
+ boolean runSequential = getOption(DefaultOptionCreator.METHOD_OPTION)
+ .equalsIgnoreCase(DefaultOptionCreator.SEQUENTIAL_METHOD);
double clusterClassificationThreshold = 0.0;
if (hasOption(DefaultOptionCreator.OUTLIER_THRESHOLD)) {
- clusterClassificationThreshold =
Double.parseDouble(getOption(DefaultOptionCreator.OUTLIER_THRESHOLD));
+ clusterClassificationThreshold = Double
+ .parseDouble(getOption(DefaultOptionCreator.OUTLIER_THRESHOLD));
}
- run(input, clustersIn, output, clusterClassificationThreshold,
runSequential);
+ run(input, clustersIn, output, clusterClassificationThreshold, true,
+ runSequential);
return 0;
}
@@ -97,7 +103,8 @@ public class ClusterClassificationDriver
private ClusterClassificationDriver() {}
public static void main(String[] args) throws Exception {
- ToolRunner.run(new Configuration(), new ClusterClassificationDriver(),
args);
+ ToolRunner
+ .run(new Configuration(), new ClusterClassificationDriver(), args);
}
/**
@@ -117,27 +124,36 @@ public class ClusterClassificationDriver
* classified for the cluster.
* @param runSequential
* Run the process sequentially or in a mapreduce way.
+ * @param emitMostLikely
+ * Emit the vector to only its most likely cluster when true.
* @throws IOException
* @throws InterruptedException
* @throws ClassNotFoundException
*/
- public static void run(Path input, Path clusteringOutputPath, Path output,
Double clusterClassificationThreshold,
- boolean runSequential) throws IOException, InterruptedException,
ClassNotFoundException {
+ public static void run(Path input, Path clusteringOutputPath, Path output,
+ Double clusterClassificationThreshold, boolean emitMostLikely,
+ boolean runSequential) throws IOException, InterruptedException,
+ ClassNotFoundException {
if (runSequential) {
- classifyClusterSeq(input, clusteringOutputPath, output,
clusterClassificationThreshold);
+ classifyClusterSeq(input, clusteringOutputPath, output,
+ clusterClassificationThreshold, emitMostLikely);
} else {
Configuration conf = new Configuration();
- classifyClusterMR(conf, input, clusteringOutputPath, output,
clusterClassificationThreshold);
+ classifyClusterMR(conf, input, clusteringOutputPath, output,
+ clusterClassificationThreshold, emitMostLikely);
}
}
- private static void classifyClusterSeq(Path input, Path clusters, Path
output, Double clusterClassificationThreshold)
+ private static void classifyClusterSeq(Path input, Path clusters,
+ Path output, Double clusterClassificationThreshold, boolean
emitMostLikely)
throws IOException {
List<Cluster> clusterModels = populateClusterModels(clusters);
- ClusteringPolicy policy =
ClusterClassifier.readPolicy(finalClustersPath(clusters));
- ClusterClassifier clusterClassifier = new ClusterClassifier(clusterModels,
policy);
- selectCluster(input, clusterModels, clusterClassifier, output,
clusterClassificationThreshold);
+ ClusteringPolicy policy = ClusterClassifier
+ .readPolicy(finalClustersPath(clusters));
+ ClusterClassifier clusterClassifier = new ClusterClassifier(clusterModels,
+ policy);
+ selectCluster(input, clusterModels, clusterClassifier, output,
+ clusterClassificationThreshold, emitMostLikely);
}
@@ -149,12 +165,14 @@ public class ClusterClassificationDriver
* @return The list of clusters found by the clustering.
* @throws IOException
*/
- private static List<Cluster> populateClusterModels(Path clusterOutputPath)
throws IOException {
+ private static List<Cluster> populateClusterModels(Path clusterOutputPath)
+ throws IOException {
List<Cluster> clusterModels = new ArrayList<Cluster>();
Cluster cluster = null;
Path finalClustersPath = finalClustersPath(clusterOutputPath);
- Iterator<?> it = new
SequenceFileDirValueIterator<Writable>(finalClustersPath, PathType.LIST,
- PathFilters.partFilter(), null, false, new Configuration());
+ Iterator<?> it = new SequenceFileDirValueIterator<Writable>(
+ finalClustersPath, PathType.LIST, PathFilters.partFilter(), null,
+ false, new Configuration());
while (it.hasNext()) {
cluster = (Cluster) it.next();
clusterModels.add(cluster);
@@ -162,9 +180,12 @@ public class ClusterClassificationDriver
return clusterModels;
}
- private static Path finalClustersPath(Path clusterOutputPath) throws
IOException {
- FileSystem fileSystem = clusterOutputPath.getFileSystem(new
Configuration());
- FileStatus[] clusterFiles = fileSystem.listStatus(clusterOutputPath,
PathFilters.finalPartFilter());
+ private static Path finalClustersPath(Path clusterOutputPath)
+ throws IOException {
+ FileSystem fileSystem = clusterOutputPath
+ .getFileSystem(new Configuration());
+ FileStatus[] clusterFiles = fileSystem.listStatus(clusterOutputPath,
+ PathFilters.finalPartFilter());
Path finalClustersPath = clusterFiles[0].getPath();
return finalClustersPath;
}
@@ -181,45 +202,84 @@ public class ClusterClassificationDriver
* @param output
* the path to store classified data
* @param clusterClassificationThreshold
+ * @param emitMostLikely
+ * when true, write each vector only to its most likely cluster;
+ * otherwise write it to every cluster whose pdf meets the threshold.
* @throws IOException
*/
- private static void selectCluster(Path input, List<Cluster> clusterModels,
ClusterClassifier clusterClassifier,
- Path output, Double clusterClassificationThreshold) throws IOException {
+ private static void selectCluster(Path input, List<Cluster> clusterModels,
+ ClusterClassifier clusterClassifier, Path output,
+ Double clusterClassificationThreshold, boolean emitMostLikely)
+ throws IOException {
Configuration conf = new Configuration();
- SequenceFile.Writer writer = new
SequenceFile.Writer(input.getFileSystem(conf), conf, new Path(output,
- "part-m-" + 0), IntWritable.class, VectorWritable.class);
- for (VectorWritable vw : new
SequenceFileDirValueIterable<VectorWritable>(input, PathType.LIST,
- PathFilters.logsCRCFilter(), conf)) {
+ SequenceFile.Writer writer = new SequenceFile.Writer(
+ input.getFileSystem(conf), conf, new Path(output, "part-m-" + 0),
+ IntWritable.class, VectorWritable.class);
+ for (VectorWritable vw : new SequenceFileDirValueIterable<VectorWritable>(
+ input, PathType.LIST, PathFilters.logsCRCFilter(), conf)) {
Vector pdfPerCluster = clusterClassifier.classify(vw.get());
if (shouldClassify(pdfPerCluster, clusterClassificationThreshold)) {
- int maxValueIndex = pdfPerCluster.maxValueIndex();
- Cluster cluster = clusterModels.get(maxValueIndex);
- writer.append(new IntWritable(cluster.getId()), vw);
+ classifyAndWrite(clusterModels, clusterClassificationThreshold,
+ emitMostLikely, writer, vw, pdfPerCluster);
}
}
writer.close();
}
+ private static void classifyAndWrite(List<Cluster> clusterModels,
+ Double clusterClassificationThreshold, boolean emitMostLikely,
+ SequenceFile.Writer writer, VectorWritable vw, Vector pdfPerCluster)
+ throws IOException {
+ if (emitMostLikely) {
+ int maxValueIndex = pdfPerCluster.maxValueIndex();
+ write(clusterModels, writer, vw, maxValueIndex);
+ } else {
+ writeAllAboveThreshold(clusterModels, clusterClassificationThreshold,
+ writer, vw, pdfPerCluster);
+ }
+ }
+
+ private static void writeAllAboveThreshold(List<Cluster> clusterModels,
+ Double clusterClassificationThreshold, SequenceFile.Writer writer,
+ VectorWritable vw, Vector pdfPerCluster) throws IOException {
+ Iterator<Element> iterateNonZero = pdfPerCluster.iterateNonZero();
+ while (iterateNonZero.hasNext()) {
+ Element pdf = iterateNonZero.next();
+ if (pdf.get() >= clusterClassificationThreshold) {
+ int clusterIndex = pdf.index();
+ write(clusterModels, writer, vw, clusterIndex);
+ }
+ }
+ }
+
+ private static void write(List<Cluster> clusterModels,
+ SequenceFile.Writer writer, VectorWritable vw, int maxValueIndex)
+ throws IOException {
+ Cluster cluster = clusterModels.get(maxValueIndex);
+ writer.append(new IntWritable(cluster.getId()), vw);
+ }
+
/**
* Decides whether the vector should be classified or not based on the max
pdf
* value of the clusters and threshold value.
*
- * @param pdfPerCluster
- * pdf of vector belonging to different clusters.
- * @param clusterClassificationThreshold
- * threshold below which the vectors won't be classified.
* @return whether the vector should be classified or not.
*/
- private static boolean shouldClassify(Vector pdfPerCluster, Double
clusterClassificationThreshold) {
- return pdfPerCluster.maxValue() >= clusterClassificationThreshold;
+ private static boolean shouldClassify(Vector pdfPerCluster,
+ Double clusterClassificationThreshold) {
+ boolean isMaxPDFGreatherThanThreshold = pdfPerCluster.maxValue() >=
clusterClassificationThreshold;
+ return isMaxPDFGreatherThanThreshold;
}
- private static void classifyClusterMR(Configuration conf, Path input, Path
clustersIn, Path output,
- Double clusterClassificationThreshold) throws IOException,
InterruptedException, ClassNotFoundException {
- Job job = new Job(conf, "Cluster Classification Driver running over input:
" + input);
+ private static void classifyClusterMR(Configuration conf, Path input,
+ Path clustersIn, Path output, Double clusterClassificationThreshold,
+ boolean emitMostLikely) throws IOException, InterruptedException,
+ ClassNotFoundException {
+ Job job = new Job(conf,
+ "Cluster Classification Driver running over input: " + input);
job.setJarByClass(ClusterClassificationDriver.class);
- conf.setFloat(OUTLIER_REMOVAL_THRESHOLD,
clusterClassificationThreshold.floatValue());
+ conf.setFloat(OUTLIER_REMOVAL_THRESHOLD,
+ clusterClassificationThreshold.floatValue());
conf.set(ClusterClassificationConfigKeys.CLUSTERS_IN, input.toString());
@@ -235,7 +295,8 @@ public class ClusterClassificationDriver
FileInputFormat.addInputPath(job, input);
FileOutputFormat.setOutputPath(job, output);
if (!job.waitForCompletion(true)) {
- throw new InterruptedException("Cluster Classification Driver Job failed
processing " + input);
+ throw new InterruptedException(
+ "Cluster Classification Driver Job failed processing " + input);
}
}
Modified:
mahout/trunk/core/src/test/java/org/apache/mahout/clustering/classify/ClusterClassificationDriverTest.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/classify/ClusterClassificationDriverTest.java?rev=1293874&r1=1293873&r2=1293874&view=diff
==============================================================================
---
mahout/trunk/core/src/test/java/org/apache/mahout/clustering/classify/ClusterClassificationDriverTest.java
(original)
+++
mahout/trunk/core/src/test/java/org/apache/mahout/clustering/classify/ClusterClassificationDriverTest.java
Sun Feb 26 17:02:20 2012
@@ -47,7 +47,8 @@ import com.google.common.collect.Lists;
public class ClusterClassificationDriverTest extends MahoutTestCase {
- private static final double[][] REFERENCE = { {1, 1}, {2, 1}, {1, 2}, {4,
4}, {5, 4}, {4, 5}, {5, 5}, {9, 9}, {8, 8}};
+ private static final double[][] REFERENCE = { {1, 1}, {2, 1}, {1, 2}, {4, 4},
+ {5, 4}, {4, 5}, {5, 5}, {9, 9}, {8, 8}};
private FileSystem fs;
@@ -97,7 +98,8 @@ public class ClusterClassificationDriver
conf = new Configuration();
- ClusteringTestUtils.writePointsToFile(points, new Path(pointsPath,
"file1"), fs, conf);
+ ClusteringTestUtils.writePointsToFile(points,
+ new Path(pointsPath, "file1"), fs, conf);
runClustering(pointsPath, conf);
runClassificationWithoutOutlierRemoval(conf);
collectVectorsForAssertion();
@@ -114,35 +116,42 @@ public class ClusterClassificationDriver
conf = new Configuration();
- ClusteringTestUtils.writePointsToFile(points, new Path(pointsPath,
"file1"), fs, conf);
+ ClusteringTestUtils.writePointsToFile(points,
+ new Path(pointsPath, "file1"), fs, conf);
runClustering(pointsPath, conf);
runClassificationWithOutlierRemoval(conf);
collectVectorsForAssertion();
assertVectorsWithOutlierRemoval();
}
- private void runClustering(Path pointsPath, Configuration conf) throws
IOException, InterruptedException,
- ClassNotFoundException {
- CanopyDriver.run(conf, pointsPath, clusteringOutputPath, new
ManhattanDistanceMeasure(), 3.1, 2.1, false, true);
+ private void runClustering(Path pointsPath, Configuration conf)
+ throws IOException, InterruptedException, ClassNotFoundException {
+ CanopyDriver.run(conf, pointsPath, clusteringOutputPath,
+ new ManhattanDistanceMeasure(), 3.1, 2.1, false, true);
Path finalClustersPath = new Path(clusteringOutputPath,
"clusters-0-final");
- ClusterClassifier.writePolicy(new CanopyClusteringPolicy(),
finalClustersPath);
+ ClusterClassifier.writePolicy(new CanopyClusteringPolicy(),
+ finalClustersPath);
}
- private void runClassificationWithoutOutlierRemoval(Configuration conf)
throws IOException, InterruptedException,
- ClassNotFoundException {
- ClusterClassificationDriver.run(pointsPath, clusteringOutputPath,
classifiedOutputPath, 0.0, true);
+ private void runClassificationWithoutOutlierRemoval(Configuration conf)
+ throws IOException, InterruptedException, ClassNotFoundException {
+ ClusterClassificationDriver.run(pointsPath, clusteringOutputPath,
+ classifiedOutputPath, 0.0, true, true);
}
- private void runClassificationWithOutlierRemoval(Configuration conf2) throws
IOException, InterruptedException,
- ClassNotFoundException {
- ClusterClassificationDriver.run(pointsPath, clusteringOutputPath,
classifiedOutputPath, 0.73, true);
+ private void runClassificationWithOutlierRemoval(Configuration conf2)
+ throws IOException, InterruptedException, ClassNotFoundException {
+ ClusterClassificationDriver.run(pointsPath, clusteringOutputPath,
+ classifiedOutputPath, 0.73, true, true);
}
private void collectVectorsForAssertion() throws IOException {
- Path[] partFilePaths =
FileUtil.stat2Paths(fs.globStatus(classifiedOutputPath));
+ Path[] partFilePaths = FileUtil.stat2Paths(fs
+ .globStatus(classifiedOutputPath));
FileStatus[] listStatus = fs.listStatus(partFilePaths);
for (FileStatus partFile : listStatus) {
- SequenceFile.Reader classifiedVectors = new SequenceFile.Reader(fs,
partFile.getPath(), conf);
+ SequenceFile.Reader classifiedVectors = new SequenceFile.Reader(fs,
+ partFile.getPath(), conf);
Writable clusterIdAsKey = new IntWritable();
VectorWritable point = new VectorWritable();
while (classifiedVectors.next(clusterIdAsKey, point)) {
@@ -176,30 +185,33 @@ public class ClusterClassificationDriver
private void assertThirdClusterWithoutOutlierRemoval() {
Assert.assertEquals(2, thirdCluster.size());
for (Vector vector : thirdCluster) {
- Assert.assertTrue(ArrayUtils.contains(new String[] {"{1:9.0,0:9.0}",
"{1:8.0,0:8.0}"}, vector.asFormatString()));
+ Assert.assertTrue(ArrayUtils.contains(new String[] {"{1:9.0,0:9.0}",
+ "{1:8.0,0:8.0}"}, vector.asFormatString()));
}
}
private void assertSecondClusterWithoutOutlierRemoval() {
Assert.assertEquals(4, secondCluster.size());
for (Vector vector : secondCluster) {
- Assert.assertTrue(ArrayUtils.contains(new String[] {"{1:4.0,0:4.0}",
"{1:4.0,0:5.0}", "{1:5.0,0:4.0}",
- "{1:5.0,0:5.0}"}, vector.asFormatString()));
+ Assert.assertTrue(ArrayUtils.contains(new String[] {"{1:4.0,0:4.0}",
+ "{1:4.0,0:5.0}", "{1:5.0,0:4.0}", "{1:5.0,0:5.0}"},
+ vector.asFormatString()));
}
}
private void assertFirstClusterWithoutOutlierRemoval() {
Assert.assertEquals(3, firstCluster.size());
for (Vector vector : firstCluster) {
- Assert.assertTrue(ArrayUtils.contains(new String[] {"{1:1.0,0:1.0}",
"{1:1.0,0:2.0}", "{1:2.0,0:1.0}"},
- vector.asFormatString()));
+ Assert.assertTrue(ArrayUtils.contains(new String[] {"{1:1.0,0:1.0}",
+ "{1:1.0,0:2.0}", "{1:2.0,0:1.0}"}, vector.asFormatString()));
}
}
private void assertThirdClusterWithOutlierRemoval() {
Assert.assertEquals(1, thirdCluster.size());
for (Vector vector : thirdCluster) {
- Assert.assertTrue(ArrayUtils.contains(new String[] {"{1:9.0,0:9.0}"},
vector.asFormatString()));
+ Assert.assertTrue(ArrayUtils.contains(new String[] {"{1:9.0,0:9.0}"},
+ vector.asFormatString()));
}
}
@@ -210,7 +222,8 @@ public class ClusterClassificationDriver
private void assertFirstClusterWithOutlierRemoval() {
Assert.assertEquals(1, firstCluster.size());
for (Vector vector : firstCluster) {
- Assert.assertTrue(ArrayUtils.contains(new String[] {"{1:1.0,0:1.0}"},
vector.asFormatString()));
+ Assert.assertTrue(ArrayUtils.contains(new String[] {"{1:1.0,0:1.0}"},
+ vector.asFormatString()));
}
}