Author: jeastman
Date: Thu Feb 23 02:48:03 2012
New Revision: 1292629

URL: http://svn.apache.org/viewvc?rev=1292629&view=rev
Log:
MAHOUT-933: Fixed undetected defects introduced by earlier commit.
I will run all the unit tests before every check-in
I will run all the unit tests before every check-in
I will run all the unit tests before every check-in
...

Modified:
    
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusterClassifier.java
    
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusterIterator.java
    
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/FuzzyKMeansClusteringPolicy.java
    
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationDriver.java
    
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationMapper.java
    
mahout/trunk/core/src/test/java/org/apache/mahout/clustering/classify/ClusterClassificationDriverTest.java

Modified: 
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusterClassifier.java
URL: 
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusterClassifier.java?rev=1292629&r1=1292628&r2=1292629&view=diff
==============================================================================
--- 
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusterClassifier.java
 (original)
+++ 
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusterClassifier.java
 Thu Feb 23 02:48:03 2012
@@ -170,7 +170,7 @@ public class ClusterClassifier extends A
   }
   
   public void writeToSeqFiles(Path path) throws IOException {
-    writePolicy(path);
+    writePolicy(policy, path);
     Configuration config = new Configuration();
     FileSystem fs = FileSystem.get(path.toUri(), config);
     SequenceFile.Writer writer = null;
@@ -202,7 +202,7 @@ public class ClusterClassifier extends A
     this.policy = readPolicy(path);
   }
   
-  private ClusteringPolicy readPolicy(Path path) throws IOException {
+  public static ClusteringPolicy readPolicy(Path path) throws IOException {
     Path policyPath = new Path(path, "_policy");
     Configuration config = new Configuration();
     FileSystem fs = FileSystem.get(policyPath.toUri(), config);
@@ -213,7 +213,7 @@ public class ClusterClassifier extends A
     return cpw.getValue();
   }
   
-  protected void writePolicy(Path path) throws IOException {
+  public static void writePolicy(ClusteringPolicy policy, Path path) throws 
IOException {
     Path policyPath = new Path(path, "_policy");
     Configuration config = new Configuration();
     FileSystem fs = FileSystem.get(policyPath.toUri(), config);

Modified: 
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusterIterator.java
URL: 
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusterIterator.java?rev=1292629&r1=1292628&r2=1292629&view=diff
==============================================================================
--- 
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusterIterator.java
 (original)
+++ 
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusterIterator.java
 Thu Feb 23 02:48:03 2012
@@ -146,7 +146,6 @@ public class ClusterIterator {
       InterruptedException, ClassNotFoundException {
     Configuration conf = new Configuration();
     HadoopUtil.delete(conf, outPath);
-    ClusterClassifier classifier = new ClusterClassifier(policy);
     for (int iteration = 1; iteration <= numIterations; iteration++) {
       conf.set(PRIOR_PATH_KEY, priorPath.toString());
       
@@ -172,7 +171,7 @@ public class ClusterIterator {
       if (!job.waitForCompletion(true)) {
         throw new InterruptedException("Cluster Iteration " + iteration + " 
failed processing " + priorPath);
       }
-      classifier.writePolicy(clustersOut);
+      ClusterClassifier.writePolicy(policy, clustersOut);
       FileSystem fs = FileSystem.get(outPath.toUri(), conf);
       if (isConverged(clustersOut, conf, fs)) {
         break;

Modified: 
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/FuzzyKMeansClusteringPolicy.java
URL: 
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/FuzzyKMeansClusteringPolicy.java?rev=1292629&r1=1292628&r2=1292629&view=diff
==============================================================================
--- 
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/FuzzyKMeansClusteringPolicy.java
 (original)
+++ 
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/FuzzyKMeansClusteringPolicy.java
 Thu Feb 23 02:48:03 2012
@@ -71,6 +71,20 @@ public class FuzzyKMeansClusteringPolicy
     return probabilities;
   }
   
+  @Override
+  public Vector classify(Vector data, List<Cluster> models) {
+    Collection<SoftCluster> clusters = Lists.newArrayList();
+    List<Double> distances = Lists.newArrayList();
+    for (Cluster model : models) {
+      SoftCluster sc = (SoftCluster) model;
+      clusters.add(sc);
+      distances.add(sc.getMeasure().distance(data, sc.getCenter()));
+    }
+    FuzzyKMeansClusterer fuzzyKMeansClusterer = new FuzzyKMeansClusterer();
+    fuzzyKMeansClusterer.setM(m);
+    return fuzzyKMeansClusterer.computePi(clusters, distances);
+  }
+
   /*
    * (non-Javadoc)
    * 
@@ -93,18 +107,4 @@ public class FuzzyKMeansClusteringPolicy
     this.convergenceDelta = in.readDouble();
   }
   
-  @Override
-  public Vector classify(Vector data, List<Cluster> models) {
-    Collection<SoftCluster> clusters = Lists.newArrayList();
-    List<Double> distances = Lists.newArrayList();
-    for (Cluster model : models) {
-      SoftCluster sc = (SoftCluster) model;
-      clusters.add(sc);
-      distances.add(sc.getMeasure().distance(data, sc.getCenter()));
-    }
-    FuzzyKMeansClusterer fuzzyKMeansClusterer = new FuzzyKMeansClusterer();
-    fuzzyKMeansClusterer.setM(m);
-    return fuzzyKMeansClusterer.computePi(clusters, distances);
-  }
-  
 }

Modified: 
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationDriver.java
URL: 
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationDriver.java?rev=1292629&r1=1292628&r2=1292629&view=diff
==============================================================================
--- 
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationDriver.java
 (original)
+++ 
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationDriver.java
 Thu Feb 23 02:48:03 2012
@@ -39,6 +39,7 @@ import org.apache.hadoop.mapreduce.lib.o
 import org.apache.hadoop.util.ToolRunner;
 import org.apache.mahout.clustering.Cluster;
 import org.apache.mahout.clustering.ClusterClassifier;
+import org.apache.mahout.clustering.ClusteringPolicy;
 import org.apache.mahout.clustering.WeightedVectorWritable;
 import org.apache.mahout.common.AbstractJob;
 import org.apache.mahout.common.commandline.DefaultOptionCreator;
@@ -50,189 +51,194 @@ import org.apache.mahout.math.Vector;
 import org.apache.mahout.math.VectorWritable;
 
 /**
- * Classifies the vectors into different clusters found by the clustering 
algorithm.
+ * Classifies the vectors into different clusters found by the clustering
+ * algorithm.
  */
 public class ClusterClassificationDriver extends AbstractJob {
-         
-    /**
-          * CLI to run Cluster Classification Driver.
-          */
-         @Override
-         public int run(String[] args) throws Exception {
-           
-           addInputOption();
-           addOutputOption();
-           addOption(DefaultOptionCreator.methodOption().create());
-           addOption(DefaultOptionCreator.clustersInOption()
-                   .withDescription("The input centroids, as Vectors.  Must be 
a SequenceFile of Writable, Cluster/Canopy.")
-                   .create());
-
-           if (parseArguments(args) == null) {
-             return -1;
-           }
-           
-           Path input = getInputPath();
-           Path output = getOutputPath();
-
-           if (getConf() == null) {
-             setConf(new Configuration());
-           }
-           Path clustersIn = new 
Path(getOption(DefaultOptionCreator.CLUSTERS_IN_OPTION));
-           boolean runSequential = 
getOption(DefaultOptionCreator.METHOD_OPTION).equalsIgnoreCase(
-             DefaultOptionCreator.SEQUENTIAL_METHOD);
-           
-           double clusterClassificationThreshold = 0.0;
-           if (hasOption(DefaultOptionCreator.OUTLIER_THRESHOLD)) {
-             clusterClassificationThreshold = 
Double.parseDouble(getOption(DefaultOptionCreator.OUTLIER_THRESHOLD));
-           }
-           
-      run(input, clustersIn, output, clusterClassificationThreshold , 
runSequential);
-      
-           return 0;
-         }
-         
-         /**
-          * Constructor to be used by the ToolRunner.
-          */
-         private ClusterClassificationDriver() {}
-         
-         public static void main(String[] args) throws Exception {
-           ToolRunner.run(new Configuration(), new 
ClusterClassificationDriver(), args);
-         }
-         
-         /**
-          * Uses {@link ClusterClassifier} to classify input vectors into 
their respective clusters.
-          * 
-          * @param input 
-          *         the input vectors
-          * @param clusteringOutputPath
-          *         the output path of clustering ( it reads clusters-*-final 
file from here )
-          * @param output
-          *         the location to store the classified vectors
-          * @param clusterClassificationThreshold
-          *         the threshold value of probability distribution function 
from 0.0 to 1.0. 
-          *         Any vector with pdf less that this threshold will not be 
classified for the cluster.
-          * @param runSequential
-          *         Run the process sequentially or in a mapreduce way.
-          * @throws IOException
-          * @throws InterruptedException
-          * @throws ClassNotFoundException
-          */
-         public static void run(Path input, Path clusteringOutputPath, Path 
output, Double clusterClassificationThreshold, boolean runSequential) throws 
IOException,
-                                                                               
InterruptedException,
-                                                                               
ClassNotFoundException {
-           if (runSequential) {
-             classifyClusterSeq(input, clusteringOutputPath, output, 
clusterClassificationThreshold);
-           } else {
-             Configuration conf = new Configuration();
-             classifyClusterMR(conf, input, clusteringOutputPath, output, 
clusterClassificationThreshold);
-           }
-           
-         }
-         
-         private static void classifyClusterSeq(Path input, Path clusters, 
Path output, Double clusterClassificationThreshold) throws IOException {
-           List<Cluster> clusterModels = populateClusterModels(clusters);
-           ClusterClassifier clusterClassifier = new 
ClusterClassifier(clusterModels, null);
-      selectCluster(input, clusterModels, clusterClassifier, output, 
clusterClassificationThreshold);
-      
-         }
-
-         /**
-          * Populates a list with clusters present in clusters-*-final 
directory.
-          * 
-          * @param clusterOutputPath
-          *             The output path of the clustering.
-          * @return
-          *             The list of clusters found by the clustering.
-          * @throws IOException
-          */
-    private static List<Cluster> populateClusterModels(Path clusterOutputPath) 
throws IOException {
-      List<Cluster> clusterModels = new ArrayList<Cluster>();
+  
+  /**
+   * CLI to run Cluster Classification Driver.
+   */
+  @Override
+  public int run(String[] args) throws Exception {
+    
+    addInputOption();
+    addOutputOption();
+    addOption(DefaultOptionCreator.methodOption().create());
+    addOption(DefaultOptionCreator.clustersInOption()
+        .withDescription("The input centroids, as Vectors.  Must be a 
SequenceFile of Writable, Cluster/Canopy.")
+        .create());
+    
+    if (parseArguments(args) == null) {
+      return -1;
+    }
+    
+    Path input = getInputPath();
+    Path output = getOutputPath();
+    
+    if (getConf() == null) {
+      setConf(new Configuration());
+    }
+    Path clustersIn = new 
Path(getOption(DefaultOptionCreator.CLUSTERS_IN_OPTION));
+    boolean runSequential = 
getOption(DefaultOptionCreator.METHOD_OPTION).equalsIgnoreCase(
+        DefaultOptionCreator.SEQUENTIAL_METHOD);
+    
+    double clusterClassificationThreshold = 0.0;
+    if (hasOption(DefaultOptionCreator.OUTLIER_THRESHOLD)) {
+      clusterClassificationThreshold = 
Double.parseDouble(getOption(DefaultOptionCreator.OUTLIER_THRESHOLD));
+    }
+    
+    run(input, clustersIn, output, clusterClassificationThreshold, 
runSequential);
+    
+    return 0;
+  }
+  
+  /**
+   * Constructor to be used by the ToolRunner.
+   */
+  private ClusterClassificationDriver() {}
+  
+  public static void main(String[] args) throws Exception {
+    ToolRunner.run(new Configuration(), new ClusterClassificationDriver(), 
args);
+  }
+  
+  /**
+   * Uses {@link ClusterClassifier} to classify input vectors into their
+   * respective clusters.
+   * 
+   * @param input
+   *          the input vectors
+   * @param clusteringOutputPath
+   *          the output path of clustering ( it reads clusters-*-final file
+   *          from here )
+   * @param output
+   *          the location to store the classified vectors
+   * @param clusterClassificationThreshold
+   *          the threshold value of probability distribution function from 0.0
+   *          to 1.0. Any vector with pdf less than this threshold will not be
+   *          classified for the cluster.
+   * @param runSequential
+   *          Run the process sequentially or in a mapreduce way.
+   * @throws IOException
+   * @throws InterruptedException
+   * @throws ClassNotFoundException
+   */
+  public static void run(Path input, Path clusteringOutputPath, Path output, 
Double clusterClassificationThreshold,
+      boolean runSequential) throws IOException, InterruptedException, 
ClassNotFoundException {
+    if (runSequential) {
+      classifyClusterSeq(input, clusteringOutputPath, output, 
clusterClassificationThreshold);
+    } else {
       Configuration conf = new Configuration();
-      Cluster cluster = null;
-      FileSystem fileSystem = clusterOutputPath.getFileSystem(conf);
-      FileStatus[] clusterFiles = fileSystem.listStatus(clusterOutputPath, 
PathFilters.finalPartFilter());
-      Iterator<?> it = new 
SequenceFileDirValueIterator<Writable>(clusterFiles[0].getPath(),
-                                                                  
PathType.LIST,
-                                                                  
PathFilters.partFilter(),
-                                                                  null,
-                                                                  false,
-                                                                  conf);
-      while (it.hasNext()) {
-        cluster = (Cluster) it.next();
-        clusterModels.add(cluster);
-      }
-      return clusterModels;
+      classifyClusterMR(conf, input, clusteringOutputPath, output, 
clusterClassificationThreshold);
     }
-         
-    /**
-     * Classifies the vector into its respective cluster.
-     * 
-     * @param input 
-     *            the path containing the input vector.
-     * @param clusterModels
-     *            the clusters
-     * @param clusterClassifier
-     *            used to classify the vectors into different clusters
-     * @param output
-     *            the path to store classified data
-     * @param clusterClassificationThreshold
-     * @throws IOException
-     */
-         private static void selectCluster(Path input, List<Cluster> 
clusterModels, ClusterClassifier clusterClassifier, Path output, Double 
clusterClassificationThreshold) throws IOException {
-           Configuration conf = new Configuration();
-           SequenceFile.Writer writer = new 
SequenceFile.Writer(input.getFileSystem(conf), conf, new Path(
-          output, "part-m-" + 0), IntWritable.class,
-          VectorWritable.class);
-           for (VectorWritable vw : new 
SequenceFileDirValueIterable<VectorWritable>(
-               input, PathType.LIST, PathFilters.logsCRCFilter(), conf)) {
-        Vector pdfPerCluster = clusterClassifier.classify(vw.get());
-        if(shouldClassify(pdfPerCluster, clusterClassificationThreshold)) {
-          int maxValueIndex = pdfPerCluster.maxValueIndex();
-          Cluster cluster = clusterModels.get(maxValueIndex);
-          writer.append(new IntWritable(cluster.getId()), vw);
-        }
-           }
-           writer.close();
+    
+  }
+  
+  private static void classifyClusterSeq(Path input, Path clusters, Path 
output, Double clusterClassificationThreshold)
+      throws IOException {
+    List<Cluster> clusterModels = populateClusterModels(clusters);
+    ClusteringPolicy policy = 
ClusterClassifier.readPolicy(finalClustersPath(clusters));
+    ClusterClassifier clusterClassifier = new ClusterClassifier(clusterModels, 
policy);
+    selectCluster(input, clusterModels, clusterClassifier, output, 
clusterClassificationThreshold);
+    
+  }
+  
+  /**
+   * Populates a list with clusters present in clusters-*-final directory.
+   * 
+   * @param clusterOutputPath
+   *          The output path of the clustering.
+   * @return The list of clusters found by the clustering.
+   * @throws IOException
+   */
+  private static List<Cluster> populateClusterModels(Path clusterOutputPath) 
throws IOException {
+    List<Cluster> clusterModels = new ArrayList<Cluster>();
+    Cluster cluster = null;
+    Path finalClustersPath = finalClustersPath(clusterOutputPath);
+    Iterator<?> it = new 
SequenceFileDirValueIterator<Writable>(finalClustersPath, PathType.LIST,
+        PathFilters.partFilter(), null, false, new Configuration());
+    while (it.hasNext()) {
+      cluster = (Cluster) it.next();
+      clusterModels.add(cluster);
     }
-
-         /**
-          * Decides whether the vector should be classified or not based on 
the max pdf value of the clusters and threshold value.
-          * 
-          * @param pdfPerCluster
-          *         pdf of vector belonging to different clusters.
-          * @param clusterClassificationThreshold
-          *         threshold below which the vectors won't be classified.
-          * @return whether the vector should be classified or not.
-          */
-    private static boolean shouldClassify(Vector pdfPerCluster, Double 
clusterClassificationThreshold) {
-      return pdfPerCluster.maxValue() >= clusterClassificationThreshold;
+    return clusterModels;
+  }
+  
+  private static Path finalClustersPath(Path clusterOutputPath) throws 
IOException {
+    FileSystem fileSystem = clusterOutputPath.getFileSystem(new 
Configuration());
+    FileStatus[] clusterFiles = fileSystem.listStatus(clusterOutputPath, 
PathFilters.finalPartFilter());
+    Path finalClustersPath = clusterFiles[0].getPath();
+    return finalClustersPath;
+  }
+  
+  /**
+   * Classifies the vector into its respective cluster.
+   * 
+   * @param input
+   *          the path containing the input vector.
+   * @param clusterModels
+   *          the clusters
+   * @param clusterClassifier
+   *          used to classify the vectors into different clusters
+   * @param output
+   *          the path to store classified data
+   * @param clusterClassificationThreshold
+   * @throws IOException
+   */
+  private static void selectCluster(Path input, List<Cluster> clusterModels, 
ClusterClassifier clusterClassifier,
+      Path output, Double clusterClassificationThreshold) throws IOException {
+    Configuration conf = new Configuration();
+    SequenceFile.Writer writer = new 
SequenceFile.Writer(input.getFileSystem(conf), conf, new Path(output,
+        "part-m-" + 0), IntWritable.class, VectorWritable.class);
+    for (VectorWritable vw : new 
SequenceFileDirValueIterable<VectorWritable>(input, PathType.LIST,
+        PathFilters.logsCRCFilter(), conf)) {
+      Vector pdfPerCluster = clusterClassifier.classify(vw.get());
+      if (shouldClassify(pdfPerCluster, clusterClassificationThreshold)) {
+        int maxValueIndex = pdfPerCluster.maxValueIndex();
+        Cluster cluster = clusterModels.get(maxValueIndex);
+        writer.append(new IntWritable(cluster.getId()), vw);
+      }
     }
-
-         private static void classifyClusterMR(Configuration conf, Path input, 
Path clustersIn, Path output, Double clusterClassificationThreshold) throws 
IOException,
-                                                                               
        InterruptedException,
-                                                                               
        ClassNotFoundException {
-           Job job = new Job(conf, "Cluster Classification Driver running over 
input: " + input);
-           job.setJarByClass(ClusterClassificationDriver.class);
-           
-           conf.setFloat(OUTLIER_REMOVAL_THRESHOLD, 
clusterClassificationThreshold.floatValue());
-           
-           conf.set(ClusterClassificationConfigKeys.CLUSTERS_IN, 
input.toString());
-           
-           job.setInputFormatClass(SequenceFileInputFormat.class);
-           job.setOutputFormatClass(SequenceFileOutputFormat.class);
-           
-           job.setMapperClass(ClusterClassificationMapper.class);
-           job.setNumReduceTasks(0);
-           
-           job.setOutputKeyClass(IntWritable.class);
-           job.setOutputValueClass(WeightedVectorWritable.class);
-           
-           FileInputFormat.addInputPath(job, input);
-           FileOutputFormat.setOutputPath(job, output);
-           if (!job.waitForCompletion(true)) {
-             throw new InterruptedException("Cluster Classification Driver Job 
failed processing " + input);
-           }
-         }
-         
-       }
+    writer.close();
+  }
+  
+  /**
+   * Decides whether the vector should be classified or not based on the max 
pdf
+   * value of the clusters and threshold value.
+   * 
+   * @param pdfPerCluster
+   *          pdf of vector belonging to different clusters.
+   * @param clusterClassificationThreshold
+   *          threshold below which the vectors won't be classified.
+   * @return whether the vector should be classified or not.
+   */
+  private static boolean shouldClassify(Vector pdfPerCluster, Double 
clusterClassificationThreshold) {
+    return pdfPerCluster.maxValue() >= clusterClassificationThreshold;
+  }
+  
+  private static void classifyClusterMR(Configuration conf, Path input, Path 
clustersIn, Path output,
+      Double clusterClassificationThreshold) throws IOException, 
InterruptedException, ClassNotFoundException {
+    Job job = new Job(conf, "Cluster Classification Driver running over input: 
" + input);
+    job.setJarByClass(ClusterClassificationDriver.class);
+    
+    conf.setFloat(OUTLIER_REMOVAL_THRESHOLD, 
clusterClassificationThreshold.floatValue());
+    
+    conf.set(ClusterClassificationConfigKeys.CLUSTERS_IN, input.toString());
+    
+    job.setInputFormatClass(SequenceFileInputFormat.class);
+    job.setOutputFormatClass(SequenceFileOutputFormat.class);
+    
+    job.setMapperClass(ClusterClassificationMapper.class);
+    job.setNumReduceTasks(0);
+    
+    job.setOutputKeyClass(IntWritable.class);
+    job.setOutputValueClass(WeightedVectorWritable.class);
+    
+    FileInputFormat.addInputPath(job, input);
+    FileOutputFormat.setOutputPath(job, output);
+    if (!job.waitForCompletion(true)) {
+      throw new InterruptedException("Cluster Classification Driver Job failed 
processing " + input);
+    }
+  }
+  
+}

Modified: 
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationMapper.java
URL: 
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationMapper.java?rev=1292629&r1=1292628&r2=1292629&view=diff
==============================================================================
--- 
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationMapper.java
 (original)
+++ 
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/classify/ClusterClassificationMapper.java
 Thu Feb 23 02:48:03 2012
@@ -33,6 +33,7 @@ import org.apache.hadoop.io.Writable;
 import org.apache.hadoop.mapreduce.Mapper;
 import org.apache.mahout.clustering.Cluster;
 import org.apache.mahout.clustering.ClusterClassifier;
+import org.apache.mahout.clustering.ClusteringPolicy;
 import org.apache.mahout.clustering.WeightedVectorWritable;
 import org.apache.mahout.common.iterator.sequencefile.PathFilters;
 import org.apache.mahout.common.iterator.sequencefile.PathType;
@@ -43,40 +44,39 @@ import org.apache.mahout.math.VectorWrit
 /**
  * Mapper for classifying vectors into clusters.
  */
-public class ClusterClassificationMapper extends
-    Mapper<IntWritable,VectorWritable,IntWritable,WeightedVectorWritable> {
+public class ClusterClassificationMapper extends 
Mapper<IntWritable,VectorWritable,IntWritable,WeightedVectorWritable> {
   
   private static double threshold;
   private List<Cluster> clusterModels;
   private ClusterClassifier clusterClassifier;
   private IntWritable clusterId;
   private WeightedVectorWritable weightedVW;
-
+  
   @Override
   protected void setup(Context context) throws IOException, 
InterruptedException {
-      super.setup(context);
-
-      Configuration conf = context.getConfiguration();
-      String clustersIn = 
conf.get(ClusterClassificationConfigKeys.CLUSTERS_IN);
-      
-      clusterModels = new ArrayList<Cluster>();
-      
-      if (clustersIn != null && !clustersIn.isEmpty()) {
-        Path clustersInPath = new Path(clustersIn, "*");
-        populateClusterModels(clustersInPath);
-        clusterClassifier = new ClusterClassifier(clusterModels, null);
-      }
-      threshold = conf.getFloat(OUTLIER_REMOVAL_THRESHOLD, 0.0f);
-      clusterId = new IntWritable();
-      weightedVW = new WeightedVectorWritable(1, null);
+    super.setup(context);
+    
+    Configuration conf = context.getConfiguration();
+    String clustersIn = conf.get(ClusterClassificationConfigKeys.CLUSTERS_IN);
+    
+    clusterModels = new ArrayList<Cluster>();
+    
+    if (clustersIn != null && !clustersIn.isEmpty()) {
+      Path clustersInPath = new Path(clustersIn, "*");
+      populateClusterModels(clustersInPath);
+      ClusteringPolicy policy = ClusterClassifier.readPolicy(clustersInPath);
+      clusterClassifier = new ClusterClassifier(clusterModels, policy);
     }
+    threshold = conf.getFloat(OUTLIER_REMOVAL_THRESHOLD, 0.0f);
+    clusterId = new IntWritable();
+    weightedVW = new WeightedVectorWritable(1, null);
+  }
   
   @Override
-  protected void map(IntWritable key, VectorWritable vw, Context context) 
throws IOException,
-                                                                               
      InterruptedException {
-    if(!clusterModels.isEmpty()) {
+  protected void map(IntWritable key, VectorWritable vw, Context context) 
throws IOException, InterruptedException {
+    if (!clusterModels.isEmpty()) {
       Vector pdfPerCluster = clusterClassifier.classify(vw.get());
-      if(shouldClassify(pdfPerCluster)) {
+      if (shouldClassify(pdfPerCluster)) {
         int maxValueIndex = pdfPerCluster.maxValueIndex();
         Cluster cluster = clusterModels.get(maxValueIndex);
         clusterId.set(cluster.getId());
@@ -92,12 +92,8 @@ public class ClusterClassificationMapper
     Cluster cluster = null;
     FileSystem fileSystem = clusterOutputPath.getFileSystem(conf);
     FileStatus[] clusterFiles = fileSystem.listStatus(clusterOutputPath, 
PathFilters.finalPartFilter());
-    Iterator<?> it = new 
SequenceFileDirValueIterator<Writable>(clusterFiles[0].getPath(),
-                                                                PathType.LIST,
-                                                                
PathFilters.partFilter(),
-                                                                null,
-                                                                false,
-                                                                conf);
+    Iterator<?> it = new 
SequenceFileDirValueIterator<Writable>(clusterFiles[0].getPath(), PathType.LIST,
+        PathFilters.partFilter(), null, false, conf);
     while (it.hasNext()) {
       cluster = (Cluster) it.next();
       clusterModels.add(cluster);

Modified: 
mahout/trunk/core/src/test/java/org/apache/mahout/clustering/classify/ClusterClassificationDriverTest.java
URL: 
http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/classify/ClusterClassificationDriverTest.java?rev=1292629&r1=1292628&r2=1292629&view=diff
==============================================================================
--- 
mahout/trunk/core/src/test/java/org/apache/mahout/clustering/classify/ClusterClassificationDriverTest.java
 (original)
+++ 
mahout/trunk/core/src/test/java/org/apache/mahout/clustering/classify/ClusterClassificationDriverTest.java
 Thu Feb 23 02:48:03 2012
@@ -32,6 +32,8 @@ import org.apache.hadoop.fs.Path;
 import org.apache.hadoop.io.IntWritable;
 import org.apache.hadoop.io.SequenceFile;
 import org.apache.hadoop.io.Writable;
+import org.apache.mahout.clustering.CanopyClusteringPolicy;
+import org.apache.mahout.clustering.ClusterClassifier;
 import org.apache.mahout.clustering.ClusteringTestUtils;
 import org.apache.mahout.clustering.canopy.CanopyDriver;
 import org.apache.mahout.common.MahoutTestCase;
@@ -44,7 +46,7 @@ import org.junit.Test;
 
 import com.google.common.collect.Lists;
 
-public class ClusterClassificationDriverTest extends MahoutTestCase{
+public class ClusterClassificationDriverTest extends MahoutTestCase {
   
   private static final double[][] REFERENCE = { {1, 1}, {2, 1}, {1, 2}, {4, 
4}, {5, 4}, {4, 5}, {5, 5}, {9, 9}, {8, 8}};
   
@@ -53,11 +55,11 @@ public class ClusterClassificationDriver
   private Path clusteringOutputPath;
   
   private Configuration conf;
-
+  
   private Path pointsPath;
-
+  
   private Path classifiedOutputPath;
-
+  
   private List<Vector> firstCluster;
   
   private List<Vector> secondCluster;
@@ -93,7 +95,7 @@ public class ClusterClassificationDriver
     pointsPath = getTestTempDirPath("points");
     clusteringOutputPath = getTestTempDirPath("output");
     classifiedOutputPath = getTestTempDirPath("classify");
-
+    
     conf = new Configuration();
     
     ClusteringTestUtils.writePointsToFile(points, new Path(pointsPath, 
"file1"), fs, conf);
@@ -110,7 +112,7 @@ public class ClusterClassificationDriver
     pointsPath = getTestTempDirPath("points");
     clusteringOutputPath = getTestTempDirPath("output");
     classifiedOutputPath = getTestTempDirPath("classify");
-
+    
     conf = new Configuration();
     
     ClusteringTestUtils.writePointsToFile(points, new Path(pointsPath, 
"file1"), fs, conf);
@@ -120,20 +122,23 @@ public class ClusterClassificationDriver
     assertVectorsWithOutlierRemoval();
   }
   
-  private void runClustering(Path pointsPath, Configuration conf) throws IOException,
-  InterruptedException,
-  ClassNotFoundException {
+  private void runClustering(Path pointsPath, Configuration conf) throws IOException, InterruptedException,
+      ClassNotFoundException {
     CanopyDriver.run(conf, pointsPath, clusteringOutputPath, new ManhattanDistanceMeasure(), 3.1, 2.1, false, true);
+    Path finalClustersPath = new Path(clusteringOutputPath, "clusters-0-final");
+    ClusterClassifier.writePolicy(new CanopyClusteringPolicy(), finalClustersPath);
   }
   
-  private void runClassificationWithoutOutlierRemoval(Configuration conf) throws IOException, InterruptedException, ClassNotFoundException {
+  private void runClassificationWithoutOutlierRemoval(Configuration conf) throws IOException, InterruptedException,
+      ClassNotFoundException {
     ClusterClassificationDriver.run(pointsPath, clusteringOutputPath, classifiedOutputPath, 0.0, true);
   }
   
-  private void runClassificationWithOutlierRemoval(Configuration conf2) throws IOException, InterruptedException, ClassNotFoundException {
+  private void runClassificationWithOutlierRemoval(Configuration conf2) throws IOException, InterruptedException,
+      ClassNotFoundException {
     ClusterClassificationDriver.run(pointsPath, clusteringOutputPath, classifiedOutputPath, 0.73, true);
   }
-
+  
   private void collectVectorsForAssertion() throws IOException {
     Path[] partFilePaths = FileUtil.stat2Paths(fs.globStatus(classifiedOutputPath));
     FileStatus[] listStatus = fs.listStatus(partFilePaths);
@@ -148,13 +153,11 @@ public class ClusterClassificationDriver
   }
   
   private void collectVector(String clusterId, Vector vector) {
-    if(clusterId.equals("0")) {
+    if (clusterId.equals("0")) {
       firstCluster.add(vector);
-    }
-    else if(clusterId.equals("1")) {
+    } else if (clusterId.equals("1")) {
       secondCluster.add(vector);
-    }
-    else if(clusterId.equals("2")) {
+    } else if (clusterId.equals("2")) {
       thirdCluster.add(vector);
     }
   }
@@ -164,53 +167,52 @@ public class ClusterClassificationDriver
     assertSecondClusterWithOutlierRemoval();
     assertThirdClusterWithOutlierRemoval();
   }
-
+  
   private void assertVectorsWithoutOutlierRemoval() {
     assertFirstClusterWithoutOutlierRemoval();
     assertSecondClusterWithoutOutlierRemoval();
     assertThirdClusterWithoutOutlierRemoval();
   }
-
+  
   private void assertThirdClusterWithoutOutlierRemoval() {
     Assert.assertEquals(2, thirdCluster.size());
     for (Vector vector : thirdCluster) {
       Assert.assertTrue(ArrayUtils.contains(new String[] {"{1:9.0,0:9.0}", "{1:8.0,0:8.0}"}, vector.asFormatString()));
     }
   }
-
+  
   private void assertSecondClusterWithoutOutlierRemoval() {
     Assert.assertEquals(4, secondCluster.size());
     for (Vector vector : secondCluster) {
-    Assert.assertTrue(ArrayUtils.contains(new String[] {"{1:4.0,0:4.0}", "{1:4.0,0:5.0}", "{1:5.0,0:4.0}",
-    "{1:5.0,0:5.0}"}, vector.asFormatString()));
+      Assert.assertTrue(ArrayUtils.contains(new String[] {"{1:4.0,0:4.0}", "{1:4.0,0:5.0}", "{1:5.0,0:4.0}",
+          "{1:5.0,0:5.0}"}, vector.asFormatString()));
     }
   }
-
+  
   private void assertFirstClusterWithoutOutlierRemoval() {
     Assert.assertEquals(3, firstCluster.size());
     for (Vector vector : firstCluster) {
-      Assert.assertTrue(ArrayUtils.contains(new String[] {"{1:1.0,0:1.0}","{1:1.0,0:2.0}", "{1:2.0,0:1.0}"}, vector.asFormatString()));
+      Assert.assertTrue(ArrayUtils.contains(new String[] {"{1:1.0,0:1.0}", "{1:1.0,0:2.0}", "{1:2.0,0:1.0}"},
+          vector.asFormatString()));
     }
   }
   
-
   private void assertThirdClusterWithOutlierRemoval() {
     Assert.assertEquals(1, thirdCluster.size());
     for (Vector vector : thirdCluster) {
       Assert.assertTrue(ArrayUtils.contains(new String[] {"{1:9.0,0:9.0}"}, vector.asFormatString()));
     }
   }
-
+  
   private void assertSecondClusterWithOutlierRemoval() {
     Assert.assertEquals(0, secondCluster.size());
   }
-
+  
   private void assertFirstClusterWithOutlierRemoval() {
     Assert.assertEquals(1, firstCluster.size());
     for (Vector vector : firstCluster) {
       Assert.assertTrue(ArrayUtils.contains(new String[] {"{1:1.0,0:1.0}"}, vector.asFormatString()));
     }
   }
-
   
 }


Reply via email to