Author: jeastman
Date: Sun Feb 12 18:44:23 2012
New Revision: 1243294
URL: http://svn.apache.org/viewvc?rev=1243294&view=rev
Log:
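MAHOUT-933: Implemented ClusterWritable to support a MapReduce (MR) version of
ClusterIterator. The MR version is not working correctly yet: it still needs to
incorporate arbitrary clustering policies (CIMapper currently hard-codes
KMeansClusteringPolicy). Still, this is a step forward. All tests run.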
Added:
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusterWritable.java
(with props)
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/CIMapper.java
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/CIReducer.java
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusterIterator.java
mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestClusterClassifier.java
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/CIMapper.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/CIMapper.java?rev=1243294&r1=1243293&r2=1243294&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/CIMapper.java
(original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/CIMapper.java
Sun Feb 12 18:44:23 2012
@@ -4,6 +4,7 @@ import java.io.IOException;
import java.util.Iterator;
import java.util.List;
+import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.WritableComparable;
import org.apache.hadoop.mapreduce.Mapper;
@@ -12,36 +13,34 @@ import org.apache.mahout.math.Vector.Ele
import org.apache.mahout.math.VectorWritable;
public class CIMapper extends
- Mapper<WritableComparable<?>,VectorWritable,IntWritable,Cluster> {
+ Mapper<WritableComparable<?>,VectorWritable,IntWritable,ClusterWritable> {
private ClusterClassifier classifier;
+
private ClusteringPolicy policy;
-
+
/*
* (non-Javadoc)
*
- * @see
- *
org.apache.hadoop.mapreduce.Mapper#setup(org.apache.hadoop.mapreduce.Mapper
- * .Context)
+ * @see
org.apache.hadoop.mapreduce.Mapper#setup(org.apache.hadoop.mapreduce.Mapper
.Context)
*/
@Override
- protected void setup(Context context) throws IOException,
- InterruptedException {
- List<Cluster> models = null;
- classifier = new ClusterClassifier(models);
+ protected void setup(Context context) throws IOException,
InterruptedException {
+ String priorClustersPath =
context.getConfiguration().get(ClusterIterator.PRIOR_PATH_KEY);
+ classifier = ClusterIterator.readClassifier(new Path(priorClustersPath));
policy = new KMeansClusteringPolicy();
super.setup(context);
}
-
+
/*
* (non-Javadoc)
*
- * @see org.apache.hadoop.mapreduce.Mapper#map(java.lang.Object,
- * java.lang.Object, org.apache.hadoop.mapreduce.Mapper.Context)
+ * @see org.apache.hadoop.mapreduce.Mapper#map(java.lang.Object,
java.lang.Object,
+ * org.apache.hadoop.mapreduce.Mapper.Context)
*/
@Override
- protected void map(WritableComparable<?> key, VectorWritable value,
- Context context) throws IOException, InterruptedException {
+ protected void map(WritableComparable<?> key, VectorWritable value, Context
context) throws IOException,
+ InterruptedException {
Vector probabilities = classifier.classify(value.get());
Vector selections = policy.select(probabilities);
for (Iterator<Element> it = selections.iterateNonZero(); it.hasNext();) {
@@ -49,22 +48,21 @@ public class CIMapper extends
classifier.train(el.index(), value.get(), el.get());
}
}
-
+
/*
* (non-Javadoc)
*
- * @see
- * org.apache.hadoop.mapreduce.Mapper#cleanup(org.apache.hadoop.mapreduce.
- * Mapper.Context)
+ * @see
org.apache.hadoop.mapreduce.Mapper#cleanup(org.apache.hadoop.mapreduce.
Mapper.Context)
*/
@Override
- protected void cleanup(Context context) throws IOException,
- InterruptedException {
+ protected void cleanup(Context context) throws IOException,
InterruptedException {
List<Cluster> clusters = classifier.getModels();
+ ClusterWritable cw = new ClusterWritable();
for (int index = 0; index < clusters.size(); index++) {
- context.write(new IntWritable(index), clusters.get(index));
+ cw.setValue(clusters.get(index));
+ context.write(new IntWritable(index), cw);
}
super.cleanup(context);
}
-
+
}
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/CIReducer.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/CIReducer.java?rev=1243294&r1=1243293&r2=1243294&view=diff
==============================================================================
--- mahout/trunk/core/src/main/java/org/apache/mahout/clustering/CIReducer.java
(original)
+++ mahout/trunk/core/src/main/java/org/apache/mahout/clustering/CIReducer.java
Sun Feb 12 18:44:23 2012
@@ -23,22 +23,22 @@ import java.util.Iterator;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.mapreduce.Reducer;
-public class CIReducer extends
Reducer<IntWritable,Cluster,IntWritable,Cluster> {
+public class CIReducer extends
Reducer<IntWritable,ClusterWritable,IntWritable,ClusterWritable> {
@Override
- protected void reduce(IntWritable key, Iterable<Cluster> values,
+ protected void reduce(IntWritable key, Iterable<ClusterWritable> values,
Context context) throws IOException, InterruptedException {
- Iterator<Cluster> iter =values.iterator();
- Cluster first = null;
+ Iterator<ClusterWritable> iter =values.iterator();
+ ClusterWritable first = null;
while(iter.hasNext()){
- Cluster cl = iter.next();
+ ClusterWritable cw = iter.next();
if (first == null){
- first = cl;
+ first = cw;
} else {
- first.observe(cl);
+ first.getValue().observe(cw.getValue());
}
}
- first.computeParameters();
+ first.getValue().computeParameters();
context.write(key, first);
}
Modified:
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusterIterator.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusterIterator.java?rev=1243294&r1=1243293&r2=1243294&view=diff
==============================================================================
---
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusterIterator.java
(original)
+++
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusterIterator.java
Sun Feb 12 18:44:23 2012
@@ -18,6 +18,8 @@ package org.apache.mahout.clustering;
import java.io.IOException;
import java.util.Iterator;
+import java.util.List;
+import java.util.Locale;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileStatus;
@@ -25,7 +27,6 @@ import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.io.IntWritable;
import org.apache.hadoop.io.SequenceFile;
-import org.apache.hadoop.io.Text;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat;
@@ -40,37 +41,33 @@ import org.apache.mahout.common.iterator
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
+import com.google.common.collect.Lists;
import com.google.common.io.Closeables;
/**
- * This is an experimental clustering iterator which works with a
- * ClusteringPolicy and a prior ClusterClassifier which has been initialized
- * with a set of models. To date, it has been tested with k-means and Dirichlet
- * clustering. See examples DisplayKMeans and DisplayDirichlet which have been
- * switched over to use it.
+ * This is an experimental clustering iterator which works with a
ClusteringPolicy and a prior ClusterClassifier which
+ * has been initialized with a set of models. To date, it has been tested with
k-means and Dirichlet clustering. See
+ * examples DisplayKMeans and DisplayDirichlet which have been switched over
to use it.
*/
public class ClusterIterator {
-
+
+ public static final String PRIOR_PATH_KEY =
"org.apache.mahout.clustering.prior.path";
+
public ClusterIterator(ClusteringPolicy policy) {
this.policy = policy;
}
-
+
private final ClusteringPolicy policy;
-
+
/**
- * Iterate over data using a prior-trained ClusterClassifier, for a number of
- * iterations
+ * Iterate over data using a prior-trained ClusterClassifier, for a number
of iterations
*
- * @param data
- * a {@code List<Vector>} of input vectors
- * @param classifier
- * a prior ClusterClassifier
- * @param numIterations
- * the int number of iterations to perform
+ * @param data a {@code List<Vector>} of input vectors
+ * @param classifier a prior ClusterClassifier
+ * @param numIterations the int number of iterations to perform
* @return the posterior ClusterClassifier
*/
- public ClusterClassifier iterate(Iterable<Vector> data,
- ClusterClassifier classifier, int numIterations) {
+ public ClusterClassifier iterate(Iterable<Vector> data, ClusterClassifier
classifier, int numIterations) {
for (int iteration = 1; iteration <= numIterations; iteration++) {
for (Vector vector : data) {
// classification yields probabilities
@@ -78,8 +75,7 @@ public class ClusterIterator {
// policy selects weights for models given those probabilities
Vector weights = policy.select(probabilities);
// training causes all models to observe data
- for (Iterator<Vector.Element> it = weights.iterateNonZero(); it
- .hasNext();) {
+ for (Iterator<Vector.Element> it = weights.iterateNonZero();
it.hasNext();) {
int index = it.next().index();
classifier.train(index, vector, weights.get(index));
}
@@ -91,36 +87,30 @@ public class ClusterIterator {
}
return classifier;
}
-
+
/**
- * Iterate over data using a prior-trained ClusterClassifier, for a number of
- * iterations using a sequential implementation
+ * Iterate over data using a prior-trained ClusterClassifier, for a number
of iterations using a sequential
+ * implementation
*
- * @param inPath
- * a Path to input VectorWritables
- * @param priorPath
- * a Path to the prior classifier
- * @param outPath
- * a Path of output directory
- * @param numIterations
- * the int number of iterations to perform
+ * @param inPath a Path to input VectorWritables
+ * @param priorPath a Path to the prior classifier
+ * @param outPath a Path of output directory
+ * @param numIterations the int number of iterations to perform
* @throws IOException
*/
- public void iterateSeq(Path inPath, Path priorPath, Path outPath,
- int numIterations) throws IOException {
+ public void iterateSeq(Path inPath, Path priorPath, Path outPath, int
numIterations) throws IOException {
ClusterClassifier classifier = readClassifier(priorPath);
Configuration conf = new Configuration();
for (int iteration = 1; iteration <= numIterations; iteration++) {
- for (VectorWritable vw : new
SequenceFileDirValueIterable<VectorWritable>(
- inPath, PathType.LIST, PathFilters.logsCRCFilter(), conf)) {
+ for (VectorWritable vw : new
SequenceFileDirValueIterable<VectorWritable>(inPath, PathType.LIST,
+ PathFilters.logsCRCFilter(), conf)) {
Vector vector = vw.get();
// classification yields probabilities
Vector probabilities = classifier.classify(vector);
// policy selects weights for models given those probabilities
Vector weights = policy.select(probabilities);
// training causes all models to observe data
- for (Iterator<Vector.Element> it = weights.iterateNonZero(); it
- .hasNext();) {
+ for (Iterator<Vector.Element> it = weights.iterateNonZero();
it.hasNext();) {
int index = it.next().index();
classifier.train(index, vector, weights.get(index));
}
@@ -130,77 +120,69 @@ public class ClusterIterator {
// update the policy
policy.update(classifier);
// output the classifier
- writeClassifier(classifier, new Path(outPath, "classifier-" + iteration),
- String.valueOf(iteration));
+ writeClassifier(classifier, new Path(outPath, "classifier-" +
iteration));
}
}
-
+
/**
- * Iterate over data using a prior-trained ClusterClassifier, for a number of
- * iterations using a mapreduce implementation
+ * Iterate over data using a prior-trained ClusterClassifier, for a number
of iterations using a mapreduce
+ * implementation
*
- * @param inPath
- * a Path to input VectorWritables
- * @param priorPath
- * a Path to the prior classifier
- * @param outPath
- * a Path of output directory
- * @param numIterations
- * the int number of iterations to perform
+ * @param inPath a Path to input VectorWritables
+ * @param priorPath a Path to the prior classifier
+ * @param outPath a Path of output directory
+ * @param numIterations the int number of iterations to perform
*/
- public static void iterateMR(Path inPath, Path priorPath, Path outPath,
- int numIterations) throws IOException,
InterruptedException,
- ClassNotFoundException {
+ public void iterateMR(Path inPath, Path priorPath, Path outPath, int
numIterations) throws IOException,
+ InterruptedException, ClassNotFoundException {
Configuration conf = new Configuration();
+ HadoopUtil.delete(conf, outPath);
for (int iteration = 1; iteration <= numIterations; iteration++) {
- conf.set("org.apache.mahout.clustering.prior.path",
priorPath.toString());
-
- Job job = new Job(conf, "Cluster Iterator running iteration " + iteration
- + " over priorPath: " + priorPath);
+ conf.set(PRIOR_PATH_KEY, priorPath.toString());
+
+ String jobName = "Cluster Iterator running iteration " + iteration + "
over priorPath: " + priorPath;
+ System.out.println(jobName);
+ Job job = new Job(conf, jobName);
job.setMapOutputKeyClass(IntWritable.class);
- job.setMapOutputValueClass(Cluster.class);
+ job.setMapOutputValueClass(ClusterWritable.class);
job.setOutputKeyClass(IntWritable.class);
- job.setOutputValueClass(Cluster.class);
-
+ job.setOutputValueClass(ClusterWritable.class);
+
job.setInputFormatClass(SequenceFileInputFormat.class);
job.setOutputFormatClass(SequenceFileOutputFormat.class);
job.setMapperClass(CIMapper.class);
job.setReducerClass(CIReducer.class);
-
+
FileInputFormat.addInputPath(job, inPath);
- FileOutputFormat.setOutputPath(job, outPath);
-
+ Path clustersOut = new Path(outPath, "clusters-" + iteration);
+ priorPath = clustersOut;
+ FileOutputFormat.setOutputPath(job, clustersOut);
+
job.setJarByClass(ClusterIterator.class);
- HadoopUtil.delete(conf, outPath);
if (!job.waitForCompletion(true)) {
- throw new InterruptedException("Cluster Iteration " + iteration
- + " failed processing " + priorPath);
+ throw new InterruptedException("Cluster Iteration " + iteration + "
failed processing " + priorPath);
}
FileSystem fs = FileSystem.get(outPath.toUri(), conf);
- if (isConverged(outPath, conf, fs)) {
+ if (isConverged(clustersOut, conf, fs)) {
break;
}
}
}
-
+
/**
- * Return if all of the Clusters in the parts in the filePath have converged
- * or not
+ * Return if all of the Clusters in the parts in the filePath have converged
or not
*
- * @param filePath
- * the file path to the single file containing the clusters
+ * @param filePath the file path to the single file containing the clusters
* @return true if all Clusters are converged
- * @throws IOException
- * if there was an IO error
+ * @throws IOException if there was an IO error
*/
- private static boolean isConverged(Path filePath, Configuration conf,
FileSystem fs)
- throws IOException {
+ private boolean isConverged(Path filePath, Configuration conf, FileSystem
fs) throws IOException {
for (FileStatus part : fs.listStatus(filePath, PathFilters.partFilter())) {
- SequenceFileValueIterator<Cluster> iterator = new
SequenceFileValueIterator<Cluster>(
+ SequenceFileValueIterator<ClusterWritable> iterator = new
SequenceFileValueIterator<ClusterWritable>(
part.getPath(), true, conf);
while (iterator.hasNext()) {
- Cluster value = iterator.next();
- if (!value.isConverged()) {
+ ClusterWritable value = iterator.next();
+ if (!value.getValue().isConverged()) {
Closeables.closeQuietly(iterator);
return false;
}
@@ -208,33 +190,34 @@ public class ClusterIterator {
}
return true;
}
-
- public static void writeClassifier(ClusterClassifier classifier,
- Path outPath, String k) throws IOException {
+
+ public static void writeClassifier(ClusterClassifier classifier, Path
outPath) throws IOException {
Configuration config = new Configuration();
FileSystem fs = FileSystem.get(outPath.toUri(), config);
- SequenceFile.Writer writer = new SequenceFile.Writer(fs, config, outPath,
- Text.class, ClusterClassifier.class);
- try {
- Writable key = new Text(k);
- writer.append(key, classifier);
- } finally {
- Closeables.closeQuietly(writer);
+ SequenceFile.Writer writer = null;
+ ClusterWritable cw = new ClusterWritable();
+ for (int i = 0; i < classifier.getModels().size(); i++) {
+ try {
+ Cluster cluster = classifier.getModels().get(i);
+ cw.setValue(cluster);
+ writer = new SequenceFile.Writer(fs, config, new Path(outPath, "part-"
+ + String.format(Locale.ENGLISH, "%05d", i)), IntWritable.class,
ClusterWritable.class);
+ Writable key = new IntWritable(i);
+ writer.append(key, cw);
+ } finally {
+ Closeables.closeQuietly(writer);
+ }
}
}
-
- public static ClusterClassifier readClassifier(Path inPath)
- throws IOException {
+
+ public static ClusterClassifier readClassifier(Path inPath) throws
IOException {
Configuration config = new Configuration();
- FileSystem fs = FileSystem.get(inPath.toUri(), config);
- SequenceFile.Reader reader = new SequenceFile.Reader(fs, inPath, config);
- Writable key = new Text();
- ClusterClassifier classifierOut = new ClusterClassifier();
- try {
- reader.next(key, classifierOut);
- } finally {
- Closeables.closeQuietly(reader);
+ List<Cluster> clusters = Lists.newArrayList();
+ for (ClusterWritable cw : new
SequenceFileDirValueIterable<ClusterWritable>(inPath, PathType.LIST,
+ PathFilters.logsCRCFilter(), config)) {
+ clusters.add(cw.getValue());
}
+ ClusterClassifier classifierOut = new ClusterClassifier(clusters);
return classifierOut;
}
}
Added:
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusterWritable.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusterWritable.java?rev=1243294&view=auto
==============================================================================
---
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusterWritable.java
(added)
+++
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusterWritable.java
Sun Feb 12 18:44:23 2012
@@ -0,0 +1,48 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.mahout.clustering;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+
+import org.apache.hadoop.io.Writable;
+import org.apache.mahout.classifier.sgd.PolymorphicWritable;
+
+public class ClusterWritable implements Writable {
+
+ private Cluster value;
+
+ public Cluster getValue() {
+ return value;
+ }
+
+ public void setValue(Cluster value) {
+ this.value = value;
+ }
+
+ @Override
+ public void write(DataOutput out) throws IOException {
+ PolymorphicWritable.write(out, value);
+ }
+
+ @Override
+ public void readFields(DataInput in) throws IOException {
+ value = PolymorphicWritable.read(in, Cluster.class);
+ }
+
+}
Propchange:
mahout/trunk/core/src/main/java/org/apache/mahout/clustering/ClusterWritable.java
------------------------------------------------------------------------------
svn:mime-type = text/plain
Modified:
mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestClusterClassifier.java
URL:
http://svn.apache.org/viewvc/mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestClusterClassifier.java?rev=1243294&r1=1243293&r2=1243294&view=diff
==============================================================================
---
mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestClusterClassifier.java
(original)
+++
mahout/trunk/core/src/test/java/org/apache/mahout/clustering/TestClusterClassifier.java
Sun Feb 12 18:44:23 2012
@@ -20,14 +20,9 @@ package org.apache.mahout.clustering;
import java.io.IOException;
import java.util.List;
-import com.google.common.collect.Lists;
-import com.google.common.io.Closeables;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
-import org.apache.hadoop.io.SequenceFile;
-import org.apache.hadoop.io.Text;
-import org.apache.hadoop.io.Writable;
import org.apache.mahout.clustering.canopy.Canopy;
import org.apache.mahout.clustering.dirichlet.models.GaussianCluster;
import org.apache.mahout.clustering.fuzzykmeans.SoftCluster;
@@ -41,31 +36,28 @@ import org.apache.mahout.math.Vector;
import org.apache.mahout.math.VectorWritable;
import org.junit.Test;
+import com.google.common.collect.Lists;
+
public final class TestClusterClassifier extends MahoutTestCase {
-
+
private static ClusterClassifier newDMClassifier() {
List<Cluster> models = Lists.newArrayList();
DistanceMeasure measure = new ManhattanDistanceMeasure();
- models.add(new DistanceMeasureCluster(new DenseVector(2).assign(1), 0,
- measure));
+ models.add(new DistanceMeasureCluster(new DenseVector(2).assign(1), 0,
measure));
models.add(new DistanceMeasureCluster(new DenseVector(2), 1, measure));
- models.add(new DistanceMeasureCluster(new DenseVector(2).assign(-1), 2,
- measure));
+ models.add(new DistanceMeasureCluster(new DenseVector(2).assign(-1), 2,
measure));
return new ClusterClassifier(models);
}
-
+
private static ClusterClassifier newClusterClassifier() {
List<Cluster> models = Lists.newArrayList();
DistanceMeasure measure = new ManhattanDistanceMeasure();
- models.add(new org.apache.mahout.clustering.kmeans.Cluster(new DenseVector(
- 2).assign(1), 0, measure));
- models.add(new org.apache.mahout.clustering.kmeans.Cluster(new DenseVector(
- 2), 1, measure));
- models.add(new org.apache.mahout.clustering.kmeans.Cluster(new DenseVector(
- 2).assign(-1), 2, measure));
+ models.add(new org.apache.mahout.clustering.kmeans.Cluster(new
DenseVector(2).assign(1), 0, measure));
+ models.add(new org.apache.mahout.clustering.kmeans.Cluster(new
DenseVector(2), 1, measure));
+ models.add(new org.apache.mahout.clustering.kmeans.Cluster(new
DenseVector(2).assign(-1), 2, measure));
return new ClusterClassifier(models);
}
-
+
private static ClusterClassifier newSoftClusterClassifier() {
List<Cluster> models = Lists.newArrayList();
DistanceMeasure measure = new ManhattanDistanceMeasure();
@@ -74,66 +66,30 @@ public final class TestClusterClassifier
models.add(new SoftCluster(new DenseVector(2).assign(-1), 2, measure));
return new ClusterClassifier(models);
}
-
+
private static ClusterClassifier newGaussianClassifier() {
List<Cluster> models = Lists.newArrayList();
- models.add(new GaussianCluster(new DenseVector(2).assign(1),
- new DenseVector(2).assign(1), 0));
- models.add(new GaussianCluster(new DenseVector(2), new DenseVector(2)
- .assign(1), 1));
- models.add(new GaussianCluster(new DenseVector(2).assign(-1),
- new DenseVector(2).assign(1), 2));
+ models.add(new GaussianCluster(new DenseVector(2).assign(1), new
DenseVector(2).assign(1), 0));
+ models.add(new GaussianCluster(new DenseVector(2), new
DenseVector(2).assign(1), 1));
+ models.add(new GaussianCluster(new DenseVector(2).assign(-1), new
DenseVector(2).assign(1), 2));
return new ClusterClassifier(models);
}
-
- private ClusterClassifier writeAndRead(ClusterClassifier classifier)
- throws IOException {
- Configuration config = new Configuration();
+
+ private ClusterClassifier writeAndRead(ClusterClassifier classifier) throws
IOException {
Path path = new Path(getTestTempDirPath(), "output");
- FileSystem fs = FileSystem.get(path.toUri(), config);
- writeClassifier(classifier, config, path, fs);
- return readClassifier(config, path, fs);
- }
-
- private static void writeClassifier(ClusterClassifier classifier,
- Configuration config,
- Path path,
- FileSystem fs) throws IOException {
- SequenceFile.Writer writer = new SequenceFile.Writer(fs, config, path,
- Text.class, ClusterClassifier.class);
- Writable key = new Text("test");
- try {
- writer.append(key, classifier);
- } finally {
- Closeables.closeQuietly(writer);
- }
- }
-
- private static ClusterClassifier readClassifier(Configuration config,
- Path path,
- FileSystem fs) throws
IOException {
- SequenceFile.Reader reader = new SequenceFile.Reader(fs, path, config);
- Writable key = new Text();
- ClusterClassifier classifierOut = new ClusterClassifier();
- try {
- reader.next(key, classifierOut);
- } finally {
- Closeables.closeQuietly(reader);
- }
- return classifierOut;
+ ClusterIterator.writeClassifier(classifier, path);
+ return ClusterIterator.readClassifier(path);
}
-
+
@Test
public void testDMClusterClassification() {
ClusterClassifier classifier = newDMClassifier();
Vector pdf = classifier.classify(new DenseVector(2));
- assertEquals("[0,0]", "[0.200, 0.600, 0.200]",
- AbstractCluster.formatVector(pdf, null));
+ assertEquals("[0,0]", "[0.200, 0.600, 0.200]",
AbstractCluster.formatVector(pdf, null));
pdf = classifier.classify(new DenseVector(2).assign(2));
- assertEquals("[2,2]", "[0.493, 0.296, 0.211]",
- AbstractCluster.formatVector(pdf, null));
+ assertEquals("[2,2]", "[0.493, 0.296, 0.211]",
AbstractCluster.formatVector(pdf, null));
}
-
+
@Test
public void testCanopyClassification() {
List<Cluster> models = Lists.newArrayList();
@@ -143,24 +99,20 @@ public final class TestClusterClassifier
models.add(new Canopy(new DenseVector(2).assign(-1), 2, measure));
ClusterClassifier classifier = new ClusterClassifier(models);
Vector pdf = classifier.classify(new DenseVector(2));
- assertEquals("[0,0]", "[0.200, 0.600, 0.200]",
- AbstractCluster.formatVector(pdf, null));
+ assertEquals("[0,0]", "[0.200, 0.600, 0.200]",
AbstractCluster.formatVector(pdf, null));
pdf = classifier.classify(new DenseVector(2).assign(2));
- assertEquals("[2,2]", "[0.493, 0.296, 0.211]",
- AbstractCluster.formatVector(pdf, null));
+ assertEquals("[2,2]", "[0.493, 0.296, 0.211]",
AbstractCluster.formatVector(pdf, null));
}
-
+
@Test
public void testClusterClassification() {
ClusterClassifier classifier = newClusterClassifier();
Vector pdf = classifier.classify(new DenseVector(2));
- assertEquals("[0,0]", "[0.200, 0.600, 0.200]",
- AbstractCluster.formatVector(pdf, null));
+ assertEquals("[0,0]", "[0.200, 0.600, 0.200]",
AbstractCluster.formatVector(pdf, null));
pdf = classifier.classify(new DenseVector(2).assign(2));
- assertEquals("[2,2]", "[0.493, 0.296, 0.211]",
- AbstractCluster.formatVector(pdf, null));
+ assertEquals("[2,2]", "[0.493, 0.296, 0.211]",
AbstractCluster.formatVector(pdf, null));
}
-
+
@Test(expected = UnsupportedOperationException.class)
public void testMSCanopyClassification() {
List<Cluster> models = Lists.newArrayList();
@@ -171,73 +123,64 @@ public final class TestClusterClassifier
ClusterClassifier classifier = new ClusterClassifier(models);
classifier.classify(new DenseVector(2));
}
-
+
@Test
public void testSoftClusterClassification() {
ClusterClassifier classifier = newSoftClusterClassifier();
Vector pdf = classifier.classify(new DenseVector(2));
- assertEquals("[0,0]", "[0.000, 1.000, 0.000]",
- AbstractCluster.formatVector(pdf, null));
+ assertEquals("[0,0]", "[0.000, 1.000, 0.000]",
AbstractCluster.formatVector(pdf, null));
pdf = classifier.classify(new DenseVector(2).assign(2));
- assertEquals("[2,2]", "[0.735, 0.184, 0.082]",
- AbstractCluster.formatVector(pdf, null));
+ assertEquals("[2,2]", "[0.735, 0.184, 0.082]",
AbstractCluster.formatVector(pdf, null));
}
-
+
@Test
public void testGaussianClusterClassification() {
ClusterClassifier classifier = newGaussianClassifier();
Vector pdf = classifier.classify(new DenseVector(2));
- assertEquals("[0,0]", "[0.212, 0.576, 0.212]",
- AbstractCluster.formatVector(pdf, null));
+ assertEquals("[0,0]", "[0.212, 0.576, 0.212]",
AbstractCluster.formatVector(pdf, null));
pdf = classifier.classify(new DenseVector(2).assign(2));
- assertEquals("[2,2]", "[0.952, 0.047, 0.000]",
- AbstractCluster.formatVector(pdf, null));
+ assertEquals("[2,2]", "[0.952, 0.047, 0.000]",
AbstractCluster.formatVector(pdf, null));
}
-
+
@Test
public void testDMClassifierSerialization() throws Exception {
ClusterClassifier classifier = newDMClassifier();
ClusterClassifier classifierOut = writeAndRead(classifier);
- assertEquals(classifier.getModels().size(), classifierOut.getModels()
- .size());
- assertEquals(classifier.getModels().get(0).getClass().getName(),
- classifierOut.getModels().get(0).getClass().getName());
+ assertEquals(classifier.getModels().size(),
classifierOut.getModels().size());
+ assertEquals(classifier.getModels().get(0).getClass().getName(),
classifierOut.getModels().get(0).getClass()
+ .getName());
}
-
+
@Test
public void testClusterClassifierSerialization() throws Exception {
ClusterClassifier classifier = newClusterClassifier();
ClusterClassifier classifierOut = writeAndRead(classifier);
- assertEquals(classifier.getModels().size(), classifierOut.getModels()
- .size());
- assertEquals(classifier.getModels().get(0).getClass().getName(),
- classifierOut.getModels().get(0).getClass().getName());
+ assertEquals(classifier.getModels().size(),
classifierOut.getModels().size());
+ assertEquals(classifier.getModels().get(0).getClass().getName(),
classifierOut.getModels().get(0).getClass()
+ .getName());
}
-
+
@Test
public void testSoftClusterClassifierSerialization() throws Exception {
ClusterClassifier classifier = newSoftClusterClassifier();
ClusterClassifier classifierOut = writeAndRead(classifier);
- assertEquals(classifier.getModels().size(), classifierOut.getModels()
- .size());
- assertEquals(classifier.getModels().get(0).getClass().getName(),
- classifierOut.getModels().get(0).getClass().getName());
+ assertEquals(classifier.getModels().size(),
classifierOut.getModels().size());
+ assertEquals(classifier.getModels().get(0).getClass().getName(),
classifierOut.getModels().get(0).getClass()
+ .getName());
}
-
+
@Test
public void testGaussianClassifierSerialization() throws Exception {
ClusterClassifier classifier = newGaussianClassifier();
ClusterClassifier classifierOut = writeAndRead(classifier);
- assertEquals(classifier.getModels().size(), classifierOut.getModels()
- .size());
- assertEquals(classifier.getModels().get(0).getClass().getName(),
- classifierOut.getModels().get(0).getClass().getName());
+ assertEquals(classifier.getModels().size(),
classifierOut.getModels().size());
+ assertEquals(classifier.getModels().get(0).getClass().getName(),
classifierOut.getModels().get(0).getClass()
+ .getName());
}
-
+
@Test
public void testClusterIteratorKMeans() {
- List<Vector> data = TestKmeansClustering
- .getPoints(TestKmeansClustering.REFERENCE);
+ List<Vector> data =
TestKmeansClustering.getPoints(TestKmeansClustering.REFERENCE);
ClusteringPolicy policy = new KMeansClusteringPolicy();
ClusterClassifier prior = newClusterClassifier();
ClusterIterator iterator = new ClusterIterator(policy);
@@ -247,11 +190,10 @@ public final class TestClusterClassifier
System.out.println(cluster.asFormatString(null));
}
}
-
+
@Test
public void testClusterIteratorDirichlet() {
- List<Vector> data = TestKmeansClustering
- .getPoints(TestKmeansClustering.REFERENCE);
+ List<Vector> data =
TestKmeansClustering.getPoints(TestKmeansClustering.REFERENCE);
ClusteringPolicy policy = new DirichletClusteringPolicy(3, 1);
ClusterClassifier prior = newClusterClassifier();
ClusterIterator iterator = new ClusterIterator(policy);
@@ -261,7 +203,7 @@ public final class TestClusterClassifier
System.out.println(cluster.asFormatString(null));
}
}
-
+
@Test
public void testSeqFileClusterIteratorKMeans() throws IOException {
Path pointsPath = getTestTempDirPath("points");
@@ -269,13 +211,11 @@ public final class TestClusterClassifier
Path outPath = getTestTempDirPath("output");
Configuration conf = new Configuration();
FileSystem fs = FileSystem.get(pointsPath.toUri(), conf);
- List<VectorWritable> points = TestKmeansClustering
- .getPointsWritable(TestKmeansClustering.REFERENCE);
- ClusteringTestUtils.writePointsToFile(points,
- new Path(pointsPath, "file1"), fs, conf);
+ List<VectorWritable> points =
TestKmeansClustering.getPointsWritable(TestKmeansClustering.REFERENCE);
+ ClusteringTestUtils.writePointsToFile(points, new Path(pointsPath,
"file1"), fs, conf);
Path path = new Path(priorPath, "priorClassifier");
ClusterClassifier prior = newClusterClassifier();
- writeClassifier(prior, conf, path, fs);
+ ClusterIterator.writeClassifier(prior, path);
assertEquals(3, prior.getModels().size());
System.out.println("Prior");
for (Cluster cluster : prior.getModels()) {
@@ -284,16 +224,47 @@ public final class TestClusterClassifier
ClusteringPolicy policy = new KMeansClusteringPolicy();
ClusterIterator iterator = new ClusterIterator(policy);
iterator.iterateSeq(pointsPath, path, outPath, 5);
-
+
+ for (int i = 1; i <= 5; i++) {
+ System.out.println("Classifier-" + i);
+ ClusterClassifier posterior = ClusterIterator.readClassifier(new
Path(outPath, "classifier-" + i));
+ assertEquals(3, posterior.getModels().size());
+ for (Cluster cluster : posterior.getModels()) {
+ System.out.println(cluster.asFormatString(null));
+ }
+
+ }
+ }
+
+ @Test
+ public void testMRFileClusterIteratorKMeans() throws IOException,
InterruptedException, ClassNotFoundException {
+ Path pointsPath = getTestTempDirPath("points");
+ Path priorPath = getTestTempDirPath("prior");
+ Path outPath = getTestTempDirPath("output");
+ Configuration conf = new Configuration();
+ FileSystem fs = FileSystem.get(pointsPath.toUri(), conf);
+ List<VectorWritable> points =
TestKmeansClustering.getPointsWritable(TestKmeansClustering.REFERENCE);
+ ClusteringTestUtils.writePointsToFile(points, new Path(pointsPath,
"file1"), fs, conf);
+ Path path = new Path(priorPath, "priorClassifier");
+ ClusterClassifier prior = newClusterClassifier();
+ ClusterIterator.writeClassifier(prior, path);
+ assertEquals(3, prior.getModels().size());
+ System.out.println("Prior");
+ for (Cluster cluster : prior.getModels()) {
+ System.out.println(cluster.asFormatString(null));
+ }
+ ClusteringPolicy policy = new KMeansClusteringPolicy();
+ ClusterIterator iterator = new ClusterIterator(policy);
+ iterator.iterateMR(pointsPath, path, outPath, 5);
+
for (int i = 1; i <= 5; i++) {
System.out.println("Classifier-" + i);
- ClusterClassifier posterior = readClassifier(conf, new Path(outPath,
- "classifier-" + i), fs);
+ ClusterClassifier posterior = ClusterIterator.readClassifier(new
Path(outPath, "clusters-" + i));
assertEquals(3, posterior.getModels().size());
for (Cluster cluster : posterior.getModels()) {
System.out.println(cluster.asFormatString(null));
}
-
+
}
}
}