Author: jeastman
Date: Thu Feb 23 02:48:03 2012
New Revision: 1292629
URL: http://svn.apache.org/viewvc?rev=1292629&view=rev
Log:
MAHOUT-933: Fixed undetected defects introduced by earlier commit.
I will run all the unit tests before every check-in
I will run all the unit tests before every check-in
I will run all the unit tests before every check-in
...
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusterClassifier.java
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusterIterator.java
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/FuzzyKMeansClusteringPolicy.java
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationDriver.java
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationMapper.java
mahout/trunk/core/src/test/java/org/apache/mahout/clustering/classify/ClusterClassificationDriverTest.java
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusterClassifier.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusterClassifier.java?rev=1292629&r1=1292628&r2=1292629&view=diff
==============================================================================
---
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusterClassifier.java
(original)
+++
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusterClassifier.java
Thu Feb 23 02:48:03 2012
@@ -170,7 +170,7 @@ public class ClusterClassifier extends A
}
public void writeToSeqFiles(Path path) throws IOException {
- writePolicy(path);
+ writePolicy(policy, path);
Configuration config = new Configuration();
FileSystem fs = FileSystem.get(path.toUri(), config);
SequenceFile.Writer writer = null;
@@ -202,7 +202,7 @@ public class ClusterClassifier extends A
this.policy = readPolicy(path);
}
- private ClusteringPolicy readPolicy(Path path) throws IOException {
+ public static ClusteringPolicy readPolicy(Path path) throws IOException {
Path policyPath = new Path(path, "_policy");
Configuration config = new Configuration();
FileSystem fs = FileSystem.get(policyPath.toUri(), config);
@@ -213,7 +213,7 @@ public class ClusterClassifier extends A
return cpw.getValue();
}
- protected void writePolicy(Path path) throws IOException {
+ public static void writePolicy(ClusteringPolicy policy, Path path) throws
IOException {
Path policyPath = new Path(path, "_policy");
Configuration config = new Configuration();
FileSystem fs = FileSystem.get(policyPath.toUri(), config);
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusterIterator.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusterIterator.java?rev=1292629&r1=1292628&r2=1292629&view=diff
==============================================================================
---
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusterIterator.java
(original)
+++
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusterIterator.java
Thu Feb 23 02:48:03 2012
@@ -146,7 +146,6 @@ public class ClusterIterator {
InterruptedException, ClassNotFoundException {
Configuration conf = new Configuration();
HadoopUtil.delete(conf, outPath);
- ClusterClassifier classifier = new ClusterClassifier(policy);
for (int iteration = 1; iteration <= numIterations; iteration++) {
conf.set(PRIOR_PATH_KEY, priorPath.toString());
@@ -172,7 +171,7 @@ public class ClusterIterator {
if (!job.waitForCompletion(true)) {
throw new InterruptedException("Cluster Iteration " + iteration + "
failed processing " + priorPath);
}
- classifier.writePolicy(clustersOut);
+ ClusterClassifier.writePolicy(policy, clustersOut);
FileSystem fs = FileSystem.get(outPath.toUri(), conf);
if (isConverged(clustersOut, conf, fs)) {
break;
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/FuzzyKMeansClusteringPolicy.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/FuzzyKMeansClusteringPolicy.java?rev=1292629&r1=1292628&r2=1292629&view=diff
==============================================================================
---
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/FuzzyKMeansClusteringPolicy.java
(original)
+++
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/FuzzyKMeansClusteringPolicy.java
Thu Feb 23 02:48:03 2012
@@ -71,6 +71,20 @@ public class FuzzyKMeansClusteringPolicy
return probabilities;
}
+ @Override
+ public Vector classify(Vector data, List<Cluster> models) {
+ Collection<SoftCluster> clusters = Lists.newArrayList();
+ List<Double> distances = Lists.newArrayList();
+ for (Cluster model : models) {
+ SoftCluster sc = (SoftCluster) model;
+ clusters.add(sc);
+ distances.add(sc.getMeasure().distance(data, sc.getCenter()));
+ }
+ FuzzyKMeansClusterer fuzzyKMeansClusterer = new FuzzyKMeansClusterer();
+ fuzzyKMeansClusterer.setM(m);
+ return fuzzyKMeansClusterer.computePi(clusters, distances);
+ }
+
/*
* (non-Javadoc)
*
@@ -93,18 +107,4 @@ public class FuzzyKMeansClusteringPolicy
this.convergenceDelta = in.readDouble();
}
- @Override
- public Vector classify(Vector data, List<Cluster> models) {
- Collection<SoftCluster> clusters = Lists.newArrayList();
- List<Double> distances = Lists.newArrayList();
- for (Cluster model : models) {
- SoftCluster sc = (SoftCluster) model;
- clusters.add(sc);
- distances.add(sc.getMeasure().distance(data, sc.getCenter()));
- }
- FuzzyKMeansClusterer fuzzyKMeansClusterer = new FuzzyKMeansClusterer();
- fuzzyKMeansClusterer.setM(m);
- return fuzzyKMeansClusterer.computePi(clusters, distances);
- }
-
}
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationDriver.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationDriver.java?rev=1292629&r1=1292628&r2=1292629&view=diff
==============================================================================
---
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationDriver.java
(original)
+++
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationDriver.java
Thu Feb 23 02:48:03 2012
@@ -39,6 +39,7 @@ import org.apache.hadoop.mapreduce.lib.o
import org.apache.hadoop.util.ToolRunner;
import org.apache.mahout.clustering.Cluster;
import org.apache.mahout.clustering.ClusterClassifier;
+import org.apache.mahout.clustering.ClusteringPolicy;
import org.apache.mahout.clustering.WeightedVectorWritable;
import org.apache.mahout.common.AbstractJob;
import org.apache.mahout.common.commandline.DefaultOptionCreator;
@@ -50,189 +51,194 @@ import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
/**
- * Classifies the vectors into different clusters found by the clustering
algorithm.
+ * Classifies the vectors into different clusters found by the clustering
+ * algorithm.
*/
public class ClusterClassificationDriver extends AbstractJob {
-
- /**
- * CLI to run Cluster Classification Driver.
- */
- @Override
- public int run(String[] args) throws Exception {
-
- addInputOption();
- addOutputOption();
- addOption(DefaultOptionCreator.methodOption().create());
- addOption(DefaultOptionCreator.clustersInOption()
- .withDescription("The input centroids, as Vectors. Must be
a SequenceFile of Writable, Cluster/Canopy.")
- .create());
-
- if (parseArguments(args) == null) {
- return -1;
- }
-
- Path input = getInputPath();
- Path output = getOutputPath();
-
- if (getConf() == null) {
- setConf(new Configuration());
- }
- Path clustersIn = new
Path(getOption(DefaultOptionCreator.CLUSTERS_IN_OPTION));
- boolean runSequential =
getOption(DefaultOptionCreator.METHOD_OPTION).equalsIgnoreCase(
- DefaultOptionCreator.SEQUENTIAL_METHOD);
-
- double clusterClassificationThreshold = 0.0;
- if (hasOption(DefaultOptionCreator.OUTLIER_THRESHOLD)) {
- clusterClassificationThreshold =
Double.parseDouble(getOption(DefaultOptionCreator.OUTLIER_THRESHOLD));
- }
-
- run(input, clustersIn, output, clusterClassificationThreshold ,
runSequential);
-
- return 0;
- }
-
- /**
- * Constructor to be used by the ToolRunner.
- */
- private ClusterClassificationDriver() {}
-
- public static void main(String[] args) throws Exception {
- ToolRunner.run(new Configuration(), new
ClusterClassificationDriver(), args);
- }
-
- /**
- * Uses {@link ClusterClassifier} to classify input vectors into
their respective clusters.
- *
- * @param input
- * the input vectors
- * @param clusteringOutputPath
- * the output path of clustering ( it reads clusters-*-final
file from here )
- * @param output
- * the location to store the classified vectors
- * @param clusterClassificationThreshold
- * the threshold value of probability distribution function
from 0.0 to 1.0.
- * Any vector with pdf less that this threshold will not be
classified for the cluster.
- * @param runSequential
- * Run the process sequentially or in a mapreduce way.
- * @throws IOException
- * @throws InterruptedException
- * @throws ClassNotFoundException
- */
- public static void run(Path input, Path clusteringOutputPath, Path
output, Double clusterClassificationThreshold, boolean runSequential) throws
IOException,
-
InterruptedException,
-
ClassNotFoundException {
- if (runSequential) {
- classifyClusterSeq(input, clusteringOutputPath, output,
clusterClassificationThreshold);
- } else {
- Configuration conf = new Configuration();
- classifyClusterMR(conf, input, clusteringOutputPath, output,
clusterClassificationThreshold);
- }
-
- }
-
- private static void classifyClusterSeq(Path input, Path clusters,
Path output, Double clusterClassificationThreshold) throws IOException {
- List<Cluster> clusterModels = populateClusterModels(clusters);
- ClusterClassifier clusterClassifier = new
ClusterClassifier(clusterModels, null);
- selectCluster(input, clusterModels, clusterClassifier, output,
clusterClassificationThreshold);
-
- }
-
- /**
- * Populates a list with clusters present in clusters-*-final
directory.
- *
- * @param clusterOutputPath
- * The output path of the clustering.
- * @return
- * The list of clusters found by the clustering.
- * @throws IOException
- */
- private static List<Cluster> populateClusterModels(Path clusterOutputPath)
throws IOException {
- List<Cluster> clusterModels = new ArrayList<Cluster>();
+
+ /**
+ * CLI to run Cluster Classification Driver.
+ */
+ @Override
+ public int run(String[] args) throws Exception {
+
+ addInputOption();
+ addOutputOption();
+ addOption(DefaultOptionCreator.methodOption().create());
+ addOption(DefaultOptionCreator.clustersInOption()
+ .withDescription("The input centroids, as Vectors. Must be a
SequenceFile of Writable, Cluster/Canopy.")
+ .create());
+
+ if (parseArguments(args) == null) {
+ return -1;
+ }
+
+ Path input = getInputPath();
+ Path output = getOutputPath();
+
+ if (getConf() == null) {
+ setConf(new Configuration());
+ }
+ Path clustersIn = new
Path(getOption(DefaultOptionCreator.CLUSTERS_IN_OPTION));
+ boolean runSequential =
getOption(DefaultOptionCreator.METHOD_OPTION).equalsIgnoreCase(
+ DefaultOptionCreator.SEQUENTIAL_METHOD);
+
+ double clusterClassificationThreshold = 0.0;
+ if (hasOption(DefaultOptionCreator.OUTLIER_THRESHOLD)) {
+ clusterClassificationThreshold =
Double.parseDouble(getOption(DefaultOptionCreator.OUTLIER_THRESHOLD));
+ }
+
+ run(input, clustersIn, output, clusterClassificationThreshold,
runSequential);
+
+ return 0;
+ }
+
+ /**
+ * Constructor to be used by the ToolRunner.
+ */
+ private ClusterClassificationDriver() {}
+
+ public static void main(String[] args) throws Exception {
+ ToolRunner.run(new Configuration(), new ClusterClassificationDriver(),
args);
+ }
+
+ /**
+ * Uses {@link ClusterClassifier} to classify input vectors into their
+ * respective clusters.
+ *
+ * @param input
+ * the input vectors
+ * @param clusteringOutputPath
+ * the output path of clustering ( it reads clusters-*-final file
+ * from here )
+ * @param output
+ * the location to store the classified vectors
+ * @param clusterClassificationThreshold
+ *          the threshold value of probability distribution function from 0.0
+ *          to 1.0. Any vector with pdf less than this threshold will not be
+ *          classified for the cluster.
+ * @param runSequential
+ * Run the process sequentially or in a mapreduce way.
+ * @throws IOException
+ * @throws InterruptedException
+ * @throws ClassNotFoundException
+ */
+ public static void run(Path input, Path clusteringOutputPath, Path output,
Double clusterClassificationThreshold,
+ boolean runSequential) throws IOException, InterruptedException,
ClassNotFoundException {
+ if (runSequential) {
+ classifyClusterSeq(input, clusteringOutputPath, output,
clusterClassificationThreshold);
+ } else {
Configuration conf = new Configuration();
- Cluster cluster = null;
- FileSystem fileSystem = clusterOutputPath.getFileSystem(conf);
- FileStatus[] clusterFiles = fileSystem.listStatus(clusterOutputPath,
PathFilters.finalPartFilter());
- Iterator<?> it = new
SequenceFileDirValueIterator<Writable>(clusterFiles[0].getPath(),
-
PathType.LIST,
-
PathFilters.partFilter(),
- null,
- false,
- conf);
- while (it.hasNext()) {
- cluster = (Cluster) it.next();
- clusterModels.add(cluster);
- }
- return clusterModels;
+ classifyClusterMR(conf, input, clusteringOutputPath, output,
clusterClassificationThreshold);
}
-
- /**
- * Classifies the vector into its respective cluster.
- *
- * @param input
- * the path containing the input vector.
- * @param clusterModels
- * the clusters
- * @param clusterClassifier
- * used to classify the vectors into different clusters
- * @param output
- * the path to store classified data
- * @param clusterClassificationThreshold
- * @throws IOException
- */
- private static void selectCluster(Path input, List<Cluster>
clusterModels, ClusterClassifier clusterClassifier, Path output, Double
clusterClassificationThreshold) throws IOException {
- Configuration conf = new Configuration();
- SequenceFile.Writer writer = new
SequenceFile.Writer(input.getFileSystem(conf), conf, new Path(
- output, "part-m-" + 0), IntWritable.class,
- VectorWritable.class);
- for (VectorWritable vw : new
SequenceFileDirValueIterable<VectorWritable>(
- input, PathType.LIST, PathFilters.logsCRCFilter(), conf)) {
- Vector pdfPerCluster = clusterClassifier.classify(vw.get());
- if(shouldClassify(pdfPerCluster, clusterClassificationThreshold)) {
- int maxValueIndex = pdfPerCluster.maxValueIndex();
- Cluster cluster = clusterModels.get(maxValueIndex);
- writer.append(new IntWritable(cluster.getId()), vw);
- }
- }
- writer.close();
+
+ }
+
+ private static void classifyClusterSeq(Path input, Path clusters, Path
output, Double clusterClassificationThreshold)
+ throws IOException {
+ List<Cluster> clusterModels = populateClusterModels(clusters);
+ ClusteringPolicy policy =
ClusterClassifier.readPolicy(finalClustersPath(clusters));
+ ClusterClassifier clusterClassifier = new ClusterClassifier(clusterModels,
policy);
+ selectCluster(input, clusterModels, clusterClassifier, output,
clusterClassificationThreshold);
+
+ }
+
+ /**
+ * Populates a list with clusters present in clusters-*-final directory.
+ *
+ * @param clusterOutputPath
+ * The output path of the clustering.
+ * @return The list of clusters found by the clustering.
+ * @throws IOException
+ */
+ private static List<Cluster> populateClusterModels(Path clusterOutputPath)
throws IOException {
+ List<Cluster> clusterModels = new ArrayList<Cluster>();
+ Cluster cluster = null;
+ Path finalClustersPath = finalClustersPath(clusterOutputPath);
+ Iterator<?> it = new
SequenceFileDirValueIterator<Writable>(finalClustersPath, PathType.LIST,
+ PathFilters.partFilter(), null, false, new Configuration());
+ while (it.hasNext()) {
+ cluster = (Cluster) it.next();
+ clusterModels.add(cluster);
}
-
- /**
- * Decides whether the vector should be classified or not based on
the max pdf value of the clusters and threshold value.
- *
- * @param pdfPerCluster
- * pdf of vector belonging to different clusters.
- * @param clusterClassificationThreshold
- * threshold below which the vectors won't be classified.
- * @return whether the vector should be classified or not.
- */
- private static boolean shouldClassify(Vector pdfPerCluster, Double
clusterClassificationThreshold) {
- return pdfPerCluster.maxValue() >= clusterClassificationThreshold;
+ return clusterModels;
+ }
+
+ private static Path finalClustersPath(Path clusterOutputPath) throws
IOException {
+ FileSystem fileSystem = clusterOutputPath.getFileSystem(new
Configuration());
+ FileStatus[] clusterFiles = fileSystem.listStatus(clusterOutputPath,
PathFilters.finalPartFilter());
+ Path finalClustersPath = clusterFiles[0].getPath();
+ return finalClustersPath;
+ }
+
+ /**
+ * Classifies the vector into its respective cluster.
+ *
+ * @param input
+ * the path containing the input vector.
+ * @param clusterModels
+ * the clusters
+ * @param clusterClassifier
+ * used to classify the vectors into different clusters
+ * @param output
+ * the path to store classified data
+ * @param clusterClassificationThreshold
+ * @throws IOException
+ */
+ private static void selectCluster(Path input, List<Cluster> clusterModels,
ClusterClassifier clusterClassifier,
+ Path output, Double clusterClassificationThreshold) throws IOException {
+ Configuration conf = new Configuration();
+ SequenceFile.Writer writer = new
SequenceFile.Writer(input.getFileSystem(conf), conf, new Path(output,
+ "part-m-" + 0), IntWritable.class, VectorWritable.class);
+ for (VectorWritable vw : new
SequenceFileDirValueIterable<VectorWritable>(input, PathType.LIST,
+ PathFilters.logsCRCFilter(), conf)) {
+ Vector pdfPerCluster = clusterClassifier.classify(vw.get());
+ if (shouldClassify(pdfPerCluster, clusterClassificationThreshold)) {
+ int maxValueIndex = pdfPerCluster.maxValueIndex();
+ Cluster cluster = clusterModels.get(maxValueIndex);
+ writer.append(new IntWritable(cluster.getId()), vw);
+ }
}
-
- private static void classifyClusterMR(Configuration conf, Path input,
Path clustersIn, Path output, Double clusterClassificationThreshold) throws
IOException,
-
InterruptedException,
-
ClassNotFoundException {
- Job job = new Job(conf, "Cluster Classification Driver running over
input: " + input);
- job.setJarByClass(ClusterClassificationDriver.class);
-
- conf.setFloat(OUTLIER_REMOVAL_THRESHOLD,
clusterClassificationThreshold.floatValue());
-
- conf.set(ClusterClassificationConfigKeys.CLUSTERS_IN,
input.toString());
-
- job.setInputFormatClass(SequenceFileInputFormat.class);
- job.setOutputFormatClass(SequenceFileOutputFormat.class);
-
- job.setMapperClass(ClusterClassificationMapper.class);
- job.setNumReduceTasks(0);
-
- job.setOutputKeyClass(IntWritable.class);
- job.setOutputValueClass(WeightedVectorWritable.class);
-
- FileInputFormat.addInputPath(job, input);
- FileOutputFormat.setOutputPath(job, output);
- if (!job.waitForCompletion(true)) {
- throw new InterruptedException("Cluster Classification Driver Job
failed processing " + input);
- }
- }
-
- }
+ writer.close();
+ }
+
+ /**
+ * Decides whether the vector should be classified or not based on the max
pdf
+ * value of the clusters and threshold value.
+ *
+ * @param pdfPerCluster
+ * pdf of vector belonging to different clusters.
+ * @param clusterClassificationThreshold
+ * threshold below which the vectors won't be classified.
+ * @return whether the vector should be classified or not.
+ */
+ private static boolean shouldClassify(Vector pdfPerCluster, Double
clusterClassificationThreshold) {
+ return pdfPerCluster.maxValue() >= clusterClassificationThreshold;
+ }
+
+ private static void classifyClusterMR(Configuration conf, Path input, Path
clustersIn, Path output,
+ Double clusterClassificationThreshold) throws IOException,
InterruptedException, ClassNotFoundException {
+ Job job = new Job(conf, "Cluster Classification Driver running over input:
" + input);
+ job.setJarByClass(ClusterClassificationDriver.class);
+
+ conf.setFloat(OUTLIER_REMOVAL_THRESHOLD,
clusterClassificationThreshold.floatValue());
+
+ conf.set(ClusterClassificationConfigKeys.CLUSTERS_IN, input.toString());
+
+ job.setInputFormatClass(SequenceFileInputFormat.class);
+ job.setOutputFormatClass(SequenceFileOutputFormat.class);
+
+ job.setMapperClass(ClusterClassificationMapper.class);
+ job.setNumReduceTasks(0);
+
+ job.setOutputKeyClass(IntWritable.class);
+ job.setOutputValueClass(WeightedVectorWritable.class);
+
+ FileInputFormat.addInputPath(job, input);
+ FileOutputFormat.setOutputPath(job, output);
+ if (!job.waitForCompletion(true)) {
+ throw new InterruptedException("Cluster Classification Driver Job failed
processing " + input);
+ }
+ }
+
+}
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationMapper.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationMapper.java?rev=1292629&r1=1292628&r2=1292629&view=diff
==============================================================================
---
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationMapper.java
(original)
+++
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationMapper.java
Thu Feb 23 02:48:03 2012
@@ -33,6 +33,7 @@ import org.apache.hadoop.io.Writable;
import org.apache.hadoop.mapreduce.Mapper;
import org.apache.mahout.clustering.Cluster;
import org.apache.mahout.clustering.ClusterClassifier;
+import org.apache.mahout.clustering.ClusteringPolicy;
import org.apache.mahout.clustering.WeightedVectorWritable;
import org.apache.mahout.common.iterator.sequencefile.PathFilters;
import org.apache.mahout.common.iterator.sequencefile.PathType;
@@ -43,40 +44,39 @@ import org.apache.mahout.math.VectorWrit
/**
* Mapper for classifying vectors into clusters.
*/
-public class ClusterClassificationMapper extends
- Mapper<IntWritable,VectorWritable,IntWritable,WeightedVectorWritable> {
+public class ClusterClassificationMapper extends
Mapper<IntWritable,VectorWritable,IntWritable,WeightedVectorWritable> {
private static double threshold;
private List<Cluster> clusterModels;
private ClusterClassifier clusterClassifier;
private IntWritable clusterId;
private WeightedVectorWritable weightedVW;
-
+
@Override
protected void setup(Context context) throws IOException,
InterruptedException {
- super.setup(context);
-
- Configuration conf = context.getConfiguration();
- String clustersIn =
conf.get(ClusterClassificationConfigKeys.CLUSTERS_IN);
-
- clusterModels = new ArrayList<Cluster>();
-
- if (clustersIn != null && !clustersIn.isEmpty()) {
- Path clustersInPath = new Path(clustersIn, "*");
- populateClusterModels(clustersInPath);
- clusterClassifier = new ClusterClassifier(clusterModels, null);
- }
- threshold = conf.getFloat(OUTLIER_REMOVAL_THRESHOLD, 0.0f);
- clusterId = new IntWritable();
- weightedVW = new WeightedVectorWritable(1, null);
+ super.setup(context);
+
+ Configuration conf = context.getConfiguration();
+ String clustersIn = conf.get(ClusterClassificationConfigKeys.CLUSTERS_IN);
+
+ clusterModels = new ArrayList<Cluster>();
+
+ if (clustersIn != null && !clustersIn.isEmpty()) {
+ Path clustersInPath = new Path(clustersIn, "*");
+ populateClusterModels(clustersInPath);
+ ClusteringPolicy policy = ClusterClassifier.readPolicy(clustersInPath);
+ clusterClassifier = new ClusterClassifier(clusterModels, policy);
}
+ threshold = conf.getFloat(OUTLIER_REMOVAL_THRESHOLD, 0.0f);
+ clusterId = new IntWritable();
+ weightedVW = new WeightedVectorWritable(1, null);
+ }
@Override
- protected void map(IntWritable key, VectorWritable vw, Context context)
throws IOException,
-
InterruptedException {
- if(!clusterModels.isEmpty()) {
+ protected void map(IntWritable key, VectorWritable vw, Context context)
throws IOException, InterruptedException {
+ if (!clusterModels.isEmpty()) {
Vector pdfPerCluster = clusterClassifier.classify(vw.get());
- if(shouldClassify(pdfPerCluster)) {
+ if (shouldClassify(pdfPerCluster)) {
int maxValueIndex = pdfPerCluster.maxValueIndex();
Cluster cluster = clusterModels.get(maxValueIndex);
clusterId.set(cluster.getId());
@@ -92,12 +92,8 @@ public class ClusterClassificationMapper
Cluster cluster = null;
FileSystem fileSystem = clusterOutputPath.getFileSystem(conf);
FileStatus[] clusterFiles = fileSystem.listStatus(clusterOutputPath,
PathFilters.finalPartFilter());
- Iterator<?> it = new
SequenceFileDirValueIterator<Writable>(clusterFiles[0].getPath(),
- PathType.LIST,
-
PathFilters.partFilter(),
- null,
- false,
- conf);
+ Iterator<?> it = new
SequenceFileDirValueIterator<Writable>(clusterFiles[0].getPath(), PathType.LIST,
+ PathFilters.partFilter(), null, false, conf);
while (it.hasNext()) {
cluster = (Cluster) it.next();
clusterModels.add(cluster);
Modified:
mahout/trunk/core/src/test/java/org/apache/mahout/clustering/classify/ClusterClassificationDriverTest.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/classify/ClusterClassificationDriverTest.java?rev=1292629&r1=1292628&r2=1292629&view=diff
==============================================================================
---
mahout/trunk/core/src/test/java/org/apache/mahout/clustering/classify/ClusterClassificationDriverTest.java
(original)
+++
mahout/trunk/core/src/test/java/org/apache/mahout/clustering/classify/ClusterClassificationDriverTest.java
Thu Feb 23 02:48:03 2012
@@ -32,6 +32,8 @@ import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.SequenceFile;
import org.apache.hadoop.io.Writable;
+import org.apache.mahout.clustering.CanopyClusteringPolicy;
+import org.apache.mahout.clustering.ClusterClassifier;
import org.apache.mahout.clustering.ClusteringTestUtils;
import org.apache.mahout.clustering.canopy.CanopyDriver;
import org.apache.mahout.common.MahoutTestCase;
@@ -44,7 +46,7 @@ import org.junit.Test;
import com.google.common.collect.Lists;
-public class ClusterClassificationDriverTest extends MahoutTestCase{
+public class ClusterClassificationDriverTest extends MahoutTestCase {
private static final double[][] REFERENCE = { {1, 1}, {2, 1}, {1, 2}, {4,
4}, {5, 4}, {4, 5}, {5, 5}, {9, 9}, {8, 8}};
@@ -53,11 +55,11 @@ public class ClusterClassificationDriver
private Path clusteringOutputPath;
private Configuration conf;
-
+
private Path pointsPath;
-
+
private Path classifiedOutputPath;
-
+
private List<Vector> firstCluster;
private List<Vector> secondCluster;
@@ -93,7 +95,7 @@ public class ClusterClassificationDriver
pointsPath = getTestTempDirPath("points");
clusteringOutputPath = getTestTempDirPath("output");
classifiedOutputPath = getTestTempDirPath("classify");
-
+
conf = new Configuration();
ClusteringTestUtils.writePointsToFile(points, new Path(pointsPath,
"file1"), fs, conf);
@@ -110,7 +112,7 @@ public class ClusterClassificationDriver
pointsPath = getTestTempDirPath("points");
clusteringOutputPath = getTestTempDirPath("output");
classifiedOutputPath = getTestTempDirPath("classify");
-
+
conf = new Configuration();
ClusteringTestUtils.writePointsToFile(points, new Path(pointsPath,
"file1"), fs, conf);
@@ -120,20 +122,23 @@ public class ClusterClassificationDriver
assertVectorsWithOutlierRemoval();
}
- private void runClustering(Path pointsPath, Configuration conf) throws
IOException,
- InterruptedException,
- ClassNotFoundException {
+ private void runClustering(Path pointsPath, Configuration conf) throws
IOException, InterruptedException,
+ ClassNotFoundException {
CanopyDriver.run(conf, pointsPath, clusteringOutputPath, new
ManhattanDistanceMeasure(), 3.1, 2.1, false, true);
+ Path finalClustersPath = new Path(clusteringOutputPath,
"clusters-0-final");
+ ClusterClassifier.writePolicy(new CanopyClusteringPolicy(),
finalClustersPath);
}
- private void runClassificationWithoutOutlierRemoval(Configuration conf)
throws IOException, InterruptedException, ClassNotFoundException {
+ private void runClassificationWithoutOutlierRemoval(Configuration conf)
throws IOException, InterruptedException,
+ ClassNotFoundException {
ClusterClassificationDriver.run(pointsPath, clusteringOutputPath,
classifiedOutputPath, 0.0, true);
}
- private void runClassificationWithOutlierRemoval(Configuration conf2) throws
IOException, InterruptedException, ClassNotFoundException {
+ private void runClassificationWithOutlierRemoval(Configuration conf2) throws
IOException, InterruptedException,
+ ClassNotFoundException {
ClusterClassificationDriver.run(pointsPath, clusteringOutputPath,
classifiedOutputPath, 0.73, true);
}
-
+
private void collectVectorsForAssertion() throws IOException {
Path[] partFilePaths =
FileUtil.stat2Paths(fs.globStatus(classifiedOutputPath));
FileStatus[] listStatus = fs.listStatus(partFilePaths);
@@ -148,13 +153,11 @@ public class ClusterClassificationDriver
}
private void collectVector(String clusterId, Vector vector) {
- if(clusterId.equals("0")) {
+ if (clusterId.equals("0")) {
firstCluster.add(vector);
- }
- else if(clusterId.equals("1")) {
+ } else if (clusterId.equals("1")) {
secondCluster.add(vector);
- }
- else if(clusterId.equals("2")) {
+ } else if (clusterId.equals("2")) {
thirdCluster.add(vector);
}
}
@@ -164,53 +167,52 @@ public class ClusterClassificationDriver
assertSecondClusterWithOutlierRemoval();
assertThirdClusterWithOutlierRemoval();
}
-
+
private void assertVectorsWithoutOutlierRemoval() {
assertFirstClusterWithoutOutlierRemoval();
assertSecondClusterWithoutOutlierRemoval();
assertThirdClusterWithoutOutlierRemoval();
}
-
+
private void assertThirdClusterWithoutOutlierRemoval() {
Assert.assertEquals(2, thirdCluster.size());
for (Vector vector : thirdCluster) {
Assert.assertTrue(ArrayUtils.contains(new String[] {"{1:9.0,0:9.0}",
"{1:8.0,0:8.0}"}, vector.asFormatString()));
}
}
-
+
private void assertSecondClusterWithoutOutlierRemoval() {
Assert.assertEquals(4, secondCluster.size());
for (Vector vector : secondCluster) {
- Assert.assertTrue(ArrayUtils.contains(new String[] {"{1:4.0,0:4.0}",
"{1:4.0,0:5.0}", "{1:5.0,0:4.0}",
- "{1:5.0,0:5.0}"}, vector.asFormatString()));
+ Assert.assertTrue(ArrayUtils.contains(new String[] {"{1:4.0,0:4.0}",
"{1:4.0,0:5.0}", "{1:5.0,0:4.0}",
+ "{1:5.0,0:5.0}"}, vector.asFormatString()));
}
}
-
+
private void assertFirstClusterWithoutOutlierRemoval() {
Assert.assertEquals(3, firstCluster.size());
for (Vector vector : firstCluster) {
- Assert.assertTrue(ArrayUtils.contains(new String[]
{"{1:1.0,0:1.0}","{1:1.0,0:2.0}", "{1:2.0,0:1.0}"}, vector.asFormatString()));
+ Assert.assertTrue(ArrayUtils.contains(new String[] {"{1:1.0,0:1.0}",
"{1:1.0,0:2.0}", "{1:2.0,0:1.0}"},
+ vector.asFormatString()));
}
}
-
private void assertThirdClusterWithOutlierRemoval() {
Assert.assertEquals(1, thirdCluster.size());
for (Vector vector : thirdCluster) {
Assert.assertTrue(ArrayUtils.contains(new String[] {"{1:9.0,0:9.0}"},
vector.asFormatString()));
}
}
-
+
private void assertSecondClusterWithOutlierRemoval() {
Assert.assertEquals(0, secondCluster.size());
}
-
+
private void assertFirstClusterWithOutlierRemoval() {
Assert.assertEquals(1, firstCluster.size());
for (Vector vector : firstCluster) {
Assert.assertTrue(ArrayUtils.contains(new String[] {"{1:1.0,0:1.0}"},
vector.asFormatString()));
}
}
-
}